Merge branch 'feat-auth' into feat-profiles
This commit is contained in:
@@ -1539,18 +1539,22 @@ CREATE TABLE community_members(
|
|||||||
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
|
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
|
||||||
-- }}}
|
-- }}}
|
||||||
|
|
||||||
-- Twitch Auth {{
|
-- Twitch User Auth {{{
|
||||||
CREATE TABLE twitch_tokens(
|
CREATE TABLE twitch_user_auth(
|
||||||
userid BIGINT PRIMARY KEY,
|
userid TEXT PRIMARY KEY,
|
||||||
access_token TEXT NOT NULL,
|
access_token TEXT NOT NULL,
|
||||||
expires_at TIMESTAMPTZ NOT NULL,
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
refresh_token TEXT NOT NULL,
|
refresh_token TEXT NOT NULL,
|
||||||
obtained_at TIMESTAMPTZ
|
obtained_at TIMESTAMPTZ
|
||||||
);
|
);
|
||||||
-- }}
|
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE twitch_user_scopes(
|
||||||
|
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
scope TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||||
|
|
||||||
|
-- }}}
|
||||||
|
|
||||||
|
|
||||||
-- Analytics Data {{{
|
-- Analytics Data {{{
|
||||||
|
|||||||
@@ -11,3 +11,6 @@ pillow
|
|||||||
python-dateutil
|
python-dateutil
|
||||||
bidict
|
bidict
|
||||||
frozendict
|
frozendict
|
||||||
|
TwitchIO
|
||||||
|
websockets
|
||||||
|
TwitchAPI
|
||||||
|
|||||||
26
src/bot.py
26
src/bot.py
@@ -80,6 +80,14 @@ async def main():
|
|||||||
websockets.serve(sockets.root_handler, '', conf.wserver['port'])
|
websockets.serve(sockets.root_handler, '', conf.wserver['port'])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
crocbot = CrocBot(
|
||||||
|
config=conf,
|
||||||
|
data=db,
|
||||||
|
prefix='!',
|
||||||
|
initial_channels=conf.croccy.getlist('initial_channels'),
|
||||||
|
token=conf.croccy['token'],
|
||||||
|
)
|
||||||
|
|
||||||
lionbot = await stack.enter_async_context(
|
lionbot = await stack.enter_async_context(
|
||||||
LionBot(
|
LionBot(
|
||||||
command_prefix='!',
|
command_prefix='!',
|
||||||
@@ -90,6 +98,7 @@ async def main():
|
|||||||
config=conf,
|
config=conf,
|
||||||
initial_extensions=[
|
initial_extensions=[
|
||||||
'utils', 'core', 'analytics',
|
'utils', 'core', 'analytics',
|
||||||
|
'twitch',
|
||||||
'modules',
|
'modules',
|
||||||
'babel',
|
'babel',
|
||||||
'tracking.voice', 'tracking.text',
|
'tracking.voice', 'tracking.text',
|
||||||
@@ -104,26 +113,15 @@ async def main():
|
|||||||
translator=translator,
|
translator=translator,
|
||||||
chunk_guilds_at_startup=False,
|
chunk_guilds_at_startup=False,
|
||||||
system_monitor=system_monitor,
|
system_monitor=system_monitor,
|
||||||
|
crocbot=crocbot,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
crocbot = CrocBot(
|
|
||||||
config=conf,
|
|
||||||
data=db,
|
|
||||||
prefix='!',
|
|
||||||
initial_channels=conf.croccy.getlist('initial_channels'),
|
|
||||||
token=conf.croccy['token'],
|
|
||||||
lionbot=lionbot
|
|
||||||
)
|
|
||||||
lionbot.crocbot = crocbot
|
|
||||||
|
|
||||||
crocbot.load_module('modules')
|
|
||||||
|
|
||||||
crocstart = asyncio.create_task(start_croccy(crocbot))
|
crocstart = asyncio.create_task(start_croccy(crocbot))
|
||||||
lionstart = asyncio.create_task(start_lion(lionbot))
|
lionstart = asyncio.create_task(start_lion(lionbot))
|
||||||
await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED)
|
await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED)
|
||||||
crocstart.cancel()
|
# crocstart.cancel()
|
||||||
lionstart.cancel()
|
# lionstart.cancel()
|
||||||
|
|
||||||
async def start_lion(lionbot):
|
async def start_lion(lionbot):
|
||||||
ctx_bot.set(lionbot)
|
ctx_bot.set(lionbot)
|
||||||
|
|||||||
@@ -10,10 +10,6 @@ from data import Database
|
|||||||
from .config import Conf
|
from .config import Conf
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .LionBot import LionBot
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -21,12 +17,11 @@ class CrocBot(commands.Bot):
|
|||||||
def __init__(self, *args,
|
def __init__(self, *args,
|
||||||
config: Conf,
|
config: Conf,
|
||||||
data: Database,
|
data: Database,
|
||||||
lionbot: 'LionBot', **kwargs):
|
**kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.data = data
|
self.data = data
|
||||||
self.pubsub = pubsub.PubSubPool(self)
|
self.pubsub = pubsub.PubSubPool(self)
|
||||||
self.lionbot = lionbot
|
|
||||||
|
|
||||||
async def event_ready(self):
|
async def event_ready(self):
|
||||||
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from .errors import HandledException, SafeCancellation
|
|||||||
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus
|
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from meta.CrocBot import CrocBot
|
||||||
from core.cog import CoreCog
|
from core.cog import CoreCog
|
||||||
from core.config import ConfigCog
|
from core.config import ConfigCog
|
||||||
from tracking.voice.cog import VoiceTrackerCog
|
from tracking.voice.cog import VoiceTrackerCog
|
||||||
@@ -58,6 +59,7 @@ class LionBot(Bot):
|
|||||||
initial_extensions: List[str], web_client: ClientSession, app_ipc,
|
initial_extensions: List[str], web_client: ClientSession, app_ipc,
|
||||||
testing_guilds: List[int] = [],
|
testing_guilds: List[int] = [],
|
||||||
system_monitor: Optional[SystemMonitor] = None,
|
system_monitor: Optional[SystemMonitor] = None,
|
||||||
|
crocbot: Optional['CrocBot'] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
kwargs.setdefault('tree_cls', LionTree)
|
kwargs.setdefault('tree_cls', LionTree)
|
||||||
@@ -73,6 +75,8 @@ class LionBot(Bot):
|
|||||||
self.app_ipc = app_ipc
|
self.app_ipc = app_ipc
|
||||||
self.translator = translator
|
self.translator = translator
|
||||||
|
|
||||||
|
self.crocbot = crocbot
|
||||||
|
|
||||||
self.system_monitor = system_monitor or SystemMonitor()
|
self.system_monitor = system_monitor or SystemMonitor()
|
||||||
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
|
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
|
||||||
self.system_monitor.add_component(self.monitor)
|
self.system_monitor.add_component(self.monitor)
|
||||||
|
|||||||
@@ -1,23 +1,36 @@
|
|||||||
from typing import Any
|
from functools import partial
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
from discord.ext.commands import Cog
|
from discord.ext.commands import Cog
|
||||||
from discord.ext import commands as cmds
|
from discord.ext import commands as cmds
|
||||||
|
from twitchio.ext import commands
|
||||||
|
from twitchio.ext.commands import Command, Bot
|
||||||
|
from twitchio.ext.commands.meta import CogEvent
|
||||||
|
|
||||||
|
|
||||||
class LionCog(Cog):
|
class LionCog(Cog):
|
||||||
# A set of other cogs that this cog depends on
|
# A set of other cogs that this cog depends on
|
||||||
depends_on: set['LionCog'] = set()
|
depends_on: set['LionCog'] = set()
|
||||||
_placeholder_groups_: set[str]
|
_placeholder_groups_: set[str]
|
||||||
|
_twitch_cmds_: dict[str, Command]
|
||||||
|
_twitch_events_: dict[str, CogEvent]
|
||||||
|
_twitch_events_loaded_: set[Callable]
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
cls._placeholder_groups_ = set()
|
cls._placeholder_groups_ = set()
|
||||||
|
cls._twitch_cmds_ = {}
|
||||||
|
cls._twitch_events_ = {}
|
||||||
|
|
||||||
for base in reversed(cls.__mro__):
|
for base in reversed(cls.__mro__):
|
||||||
for elem, value in base.__dict__.items():
|
for elem, value in base.__dict__.items():
|
||||||
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
|
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
|
||||||
cls._placeholder_groups_.add(value.name)
|
cls._placeholder_groups_.add(value.name)
|
||||||
|
elif isinstance(value, Command):
|
||||||
|
cls._twitch_cmds_[value.name] = value
|
||||||
|
elif isinstance(value, CogEvent):
|
||||||
|
cls._twitch_events_[value.name] = value
|
||||||
|
|
||||||
def __new__(cls, *args: Any, **kwargs: Any):
|
def __new__(cls, *args: Any, **kwargs: Any):
|
||||||
# Patch to ensure no placeholder groups are in the command list
|
# Patch to ensure no placeholder groups are in the command list
|
||||||
@@ -34,6 +47,51 @@ class LionCog(Cog):
|
|||||||
|
|
||||||
return await super()._inject(bot, *args, *kwargs)
|
return await super()._inject(bot, *args, *kwargs)
|
||||||
|
|
||||||
|
def _load_twitch_methods(self, bot: Bot):
|
||||||
|
for name, command in self._twitch_cmds_.items():
|
||||||
|
command._instance = self
|
||||||
|
command.cog = self
|
||||||
|
bot.add_command(command)
|
||||||
|
|
||||||
|
for name, event in self._twitch_events_.items():
|
||||||
|
callback = partial(event, self)
|
||||||
|
self._twitch_events_loaded_.add(callback)
|
||||||
|
bot.add_event(callback=callback, name=name)
|
||||||
|
|
||||||
|
def _unload_twitch_methods(self, bot: Bot):
|
||||||
|
for name in self._twitch_cmds_:
|
||||||
|
bot.remove_command(name)
|
||||||
|
|
||||||
|
for callback in self._twitch_events_loaded_:
|
||||||
|
bot.remove_event(callback=callback)
|
||||||
|
|
||||||
|
self._twitch_events_loaded_.clear()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def twitch_event(cls, event: Optional[str] = None):
|
||||||
|
def decorator(func) -> CogEvent:
|
||||||
|
event_name = event or func.__name__
|
||||||
|
return CogEvent(name=event_name, func=func, module=cls.__module__)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
async def cog_check(self, ctx): # type: ignore
|
||||||
|
"""
|
||||||
|
TwitchIO assumes cog_check is a coroutine,
|
||||||
|
so here we narrow the check to only a coroutine.
|
||||||
|
|
||||||
|
The ctx maybe either be a twitch command context or a dpy context.
|
||||||
|
"""
|
||||||
|
if isinstance(ctx, cmds.Context):
|
||||||
|
return await self.cog_check_discord(ctx)
|
||||||
|
if isinstance(ctx, commands.Context):
|
||||||
|
return await self.cog_check_twitch(ctx)
|
||||||
|
|
||||||
|
async def cog_check_discord(self, ctx: cmds.Context):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def cog_check_twitch(self, ctx: commands.Context):
|
||||||
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def placeholder_group(cls, group: cmds.HybridGroup):
|
def placeholder_group(cls, group: cmds.HybridGroup):
|
||||||
group._placeholder_group_ = True
|
group._placeholder_group_ = True
|
||||||
|
|||||||
@@ -26,20 +26,12 @@ active_discord = [
|
|||||||
'.premium',
|
'.premium',
|
||||||
'.streamalerts',
|
'.streamalerts',
|
||||||
'.test',
|
'.test',
|
||||||
]
|
'.counters',
|
||||||
|
|
||||||
active_twitch = [
|
|
||||||
'.nowdoing',
|
'.nowdoing',
|
||||||
'.shoutouts',
|
'.shoutouts',
|
||||||
'.counters',
|
|
||||||
'.tagstrings',
|
'.tagstrings',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def prepare(bot):
|
|
||||||
for ext in active_twitch:
|
|
||||||
bot.load_module(this_package + ext)
|
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot):
|
||||||
for ext in active_discord:
|
for ext in active_discord:
|
||||||
await bot.load_extension(ext, package=this_package)
|
await bot.load_extension(ext, package=this_package)
|
||||||
|
|||||||
@@ -4,10 +4,5 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
from .cog import CounterCog
|
from .cog import CounterCog
|
||||||
|
|
||||||
def prepare(bot):
|
|
||||||
bot.add_cog(CounterCog(bot))
|
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot):
|
||||||
from .lion_cog import CounterCog
|
|
||||||
|
|
||||||
await bot.add_cog(CounterCog(bot))
|
await bot.add_cog(CounterCog(bot))
|
||||||
|
|||||||
@@ -3,11 +3,14 @@ from enum import Enum
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands as cmds
|
||||||
import twitchio
|
import twitchio
|
||||||
from twitchio.ext import commands
|
from twitchio.ext import commands
|
||||||
|
|
||||||
|
|
||||||
from data.queries import ORDER
|
from data.queries import ORDER
|
||||||
from meta import CrocBot
|
from meta import LionCog, LionBot, CrocBot
|
||||||
from utils.lib import utc_now
|
from utils.lib import utc_now
|
||||||
from . import logger
|
from . import logger
|
||||||
from .data import CounterData
|
from .data import CounterData
|
||||||
@@ -22,10 +25,11 @@ class PERIOD(Enum):
|
|||||||
YEAR = ('this year', 'y', 'year', 'yearly')
|
YEAR = ('this year', 'y', 'year', 'yearly')
|
||||||
|
|
||||||
|
|
||||||
class CounterCog(commands.Cog):
|
class CounterCog(LionCog):
|
||||||
def __init__(self, bot: CrocBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.data = bot.data.load_registry(CounterData())
|
self.crocbot: CrocBot = bot.crocbot
|
||||||
|
self.data = bot.db.load_registry(CounterData())
|
||||||
|
|
||||||
self.loaded = asyncio.Event()
|
self.loaded = asyncio.Event()
|
||||||
|
|
||||||
@@ -33,9 +37,18 @@ class CounterCog(commands.Cog):
|
|||||||
self.counters = {}
|
self.counters = {}
|
||||||
|
|
||||||
async def cog_load(self):
|
async def cog_load(self):
|
||||||
|
self._load_twitch_methods(self.crocbot)
|
||||||
|
|
||||||
await self.data.init()
|
await self.data.init()
|
||||||
|
await self.load_counters()
|
||||||
self.loaded.set()
|
self.loaded.set()
|
||||||
|
|
||||||
|
async def cog_unload(self):
|
||||||
|
self._unload_twitch_methods(self.crocbot)
|
||||||
|
|
||||||
|
async def cog_check(self, ctx):
|
||||||
|
return True
|
||||||
|
|
||||||
async def load_counters(self):
|
async def load_counters(self):
|
||||||
"""
|
"""
|
||||||
Initialise counter name cache.
|
Initialise counter name cache.
|
||||||
@@ -46,18 +59,6 @@ class CounterCog(commands.Cog):
|
|||||||
f"Loaded {len(self.counters)} counters."
|
f"Loaded {len(self.counters)} counters."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def ensure_loaded(self):
|
|
||||||
if not self.loaded.is_set():
|
|
||||||
await self.cog_load()
|
|
||||||
|
|
||||||
@commands.Cog.event('event_ready') # type: ignore
|
|
||||||
async def on_ready(self):
|
|
||||||
await self.ensure_loaded()
|
|
||||||
|
|
||||||
async def cog_check(self, ctx):
|
|
||||||
await self.ensure_loaded()
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Counters API
|
# Counters API
|
||||||
|
|
||||||
async def fetch_counter(self, counter: str) -> CounterData.Counter:
|
async def fetch_counter(self, counter: str) -> CounterData.Counter:
|
||||||
@@ -171,7 +172,7 @@ class CounterCog(commands.Cog):
|
|||||||
if period is PERIOD.ALL:
|
if period is PERIOD.ALL:
|
||||||
start_time = None
|
start_time = None
|
||||||
elif period is PERIOD.STREAM:
|
elif period is PERIOD.STREAM:
|
||||||
streams = await self.bot.fetch_streams(user_ids=[userid])
|
streams = await self.crocbot.fetch_streams(user_ids=[userid])
|
||||||
if streams:
|
if streams:
|
||||||
stream = streams[0]
|
stream = streams[0]
|
||||||
start_time = stream.started_at
|
start_time = stream.started_at
|
||||||
@@ -199,7 +200,7 @@ class CounterCog(commands.Cog):
|
|||||||
lb = await self.leaderboard(counter, start_time=start_time)
|
lb = await self.leaderboard(counter, start_time=start_time)
|
||||||
if lb:
|
if lb:
|
||||||
userids = list(lb.keys())
|
userids = list(lb.keys())
|
||||||
users = await self.bot.fetch_users(ids=userids)
|
users = await self.crocbot.fetch_users(ids=userids)
|
||||||
name_map = {user.id: user.display_name for user in users}
|
name_map = {user.id: user.display_name for user in users}
|
||||||
parts = []
|
parts = []
|
||||||
for userid, total in lb.items():
|
for userid, total in lb.items():
|
||||||
@@ -283,17 +284,9 @@ class CounterCog(commands.Cog):
|
|||||||
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
|
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
|
||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
async def reload(self, ctx: commands.Context, *, args: str = ''):
|
async def stuff(self, ctx: commands.Context, *, args: str = ''):
|
||||||
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
|
await ctx.reply(f"Stuff {args}")
|
||||||
return
|
|
||||||
if not args:
|
|
||||||
await ctx.reply("Full reload not implemented yet.")
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
self.bot.reload_module(args)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to reload")
|
|
||||||
await ctx.reply("Failed to reload module! Check console~")
|
|
||||||
else:
|
|
||||||
await ctx.reply("Reloaded!")
|
|
||||||
|
|
||||||
|
@cmds.hybrid_command('water')
|
||||||
|
async def d_water_cmd(self, ctx):
|
||||||
|
await ctx.reply(repr(ctx))
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import discord
|
|
||||||
from discord.ext import commands as cmds
|
|
||||||
from discord import app_commands as appcmds
|
|
||||||
|
|
||||||
from meta import LionBot, LionCog, LionContext
|
|
||||||
from meta.errors import UserInputError
|
|
||||||
from meta.logger import log_wrap
|
|
||||||
from utils.lib import utc_now
|
|
||||||
from data.conditions import NULL
|
|
||||||
|
|
||||||
from . import logger
|
|
||||||
from .data import CounterData
|
|
||||||
|
|
||||||
|
|
||||||
class CounterCog(LionCog):
|
|
||||||
|
|
||||||
def __init__(self, bot: LionBot):
|
|
||||||
self.bot = bot
|
|
||||||
|
|
||||||
self.counter_cog = bot.crocbot.get_cog('CounterCog')
|
|
||||||
@@ -4,6 +4,5 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
from .cog import NowDoingCog
|
from .cog import NowDoingCog
|
||||||
|
|
||||||
def prepare(bot):
|
async def setup(bot):
|
||||||
logger.info("Preparing the nowdoing module.")
|
await bot.add_cog(NowDoingCog(bot))
|
||||||
bot.add_cog(NowDoingCog(bot))
|
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ from attr import dataclass
|
|||||||
import twitchio
|
import twitchio
|
||||||
from twitchio.ext import commands
|
from twitchio.ext import commands
|
||||||
|
|
||||||
from meta import CrocBot
|
from meta import CrocBot, LionCog
|
||||||
|
from meta.LionBot import LionBot
|
||||||
from meta.sockets import Channel, register_channel
|
from meta.sockets import Channel, register_channel
|
||||||
from utils.lib import strfdelta, utc_now
|
from utils.lib import strfdelta, utc_now
|
||||||
from . import logger
|
from . import logger
|
||||||
@@ -78,10 +79,11 @@ class NowDoingChannel(Channel):
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
class NowDoingCog(commands.Cog):
|
class NowDoingCog(LionCog):
|
||||||
def __init__(self, bot: CrocBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.data = bot.data.load_registry(NowListData())
|
self.crocbot = bot.crocbot
|
||||||
|
self.data = bot.db.load_registry(NowListData())
|
||||||
self.channel = NowDoingChannel(self)
|
self.channel = NowDoingChannel(self)
|
||||||
register_channel(self.channel.name, self.channel)
|
register_channel(self.channel.name, self.channel)
|
||||||
|
|
||||||
@@ -94,21 +96,19 @@ class NowDoingCog(commands.Cog):
|
|||||||
await self.data.init()
|
await self.data.init()
|
||||||
|
|
||||||
await self.load_tasks()
|
await self.load_tasks()
|
||||||
|
|
||||||
|
self._load_twitch_methods(self.crocbot)
|
||||||
self.loaded.set()
|
self.loaded.set()
|
||||||
|
|
||||||
async def ensure_loaded(self):
|
async def cog_unload(self):
|
||||||
"""
|
self.loaded.clear()
|
||||||
Hack because lib devs decided to remove async cog loading.
|
self.tasks.clear()
|
||||||
"""
|
self._unload_twitch_methods(self.crocbot)
|
||||||
if not self.loaded.is_set():
|
|
||||||
await self.cog_load()
|
|
||||||
|
|
||||||
@commands.Cog.event('event_ready') # type: ignore
|
|
||||||
async def on_ready(self):
|
|
||||||
await self.ensure_loaded()
|
|
||||||
|
|
||||||
async def cog_check(self, ctx):
|
async def cog_check(self, ctx):
|
||||||
await self.ensure_loaded()
|
if not self.loaded.is_set():
|
||||||
|
await ctx.reply("Tasklists are still loading! Please wait a moment~")
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def load_tasks(self):
|
async def load_tasks(self):
|
||||||
@@ -130,6 +130,7 @@ class NowDoingCog(commands.Cog):
|
|||||||
@commands.command(aliases=['task', 'check'])
|
@commands.command(aliases=['task', 'check'])
|
||||||
async def now(self, ctx: commands.Context, *, args: Optional[str] = None):
|
async def now(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||||
userid = int(ctx.author.id)
|
userid = int(ctx.author.id)
|
||||||
|
args = args.strip() if args else None
|
||||||
if args:
|
if args:
|
||||||
await self.data.Task.table.delete_where(userid=userid)
|
await self.data.Task.table.delete_where(userid=userid)
|
||||||
task = await self.data.Task.create(
|
task = await self.data.Task.create(
|
||||||
|
|||||||
@@ -4,5 +4,5 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
from .cog import ShoutoutCog
|
from .cog import ShoutoutCog
|
||||||
|
|
||||||
def prepare(bot):
|
async def setup(bot):
|
||||||
bot.add_cog(ShoutoutCog(bot))
|
await bot.add_cog(ShoutoutCog(bot))
|
||||||
|
|||||||
@@ -4,50 +4,50 @@ from typing import Optional
|
|||||||
import twitchio
|
import twitchio
|
||||||
from twitchio.ext import commands
|
from twitchio.ext import commands
|
||||||
|
|
||||||
from meta import CrocBot
|
from meta import CrocBot, LionBot, LionCog
|
||||||
from utils.lib import replace_multiple
|
from utils.lib import replace_multiple
|
||||||
from . import logger
|
from . import logger
|
||||||
from .data import ShoutoutData
|
from .data import ShoutoutData
|
||||||
|
|
||||||
|
|
||||||
class ShoutoutCog(commands.Cog):
|
class ShoutoutCog(LionCog):
|
||||||
# Future extension: channel defaults and config
|
# Future extension: channel defaults and config
|
||||||
DEFAULT_SHOUTOUT = """
|
DEFAULT_SHOUTOUT = """
|
||||||
We think that {name} is a great streamer and you should check them out \
|
We think that {name} is a great streamer and you should check them out \
|
||||||
and drop a follow! \
|
and drop a follow! \
|
||||||
They {areorwere} streaming {game} at {channel}
|
They {areorwere} streaming {game} at {channel}
|
||||||
"""
|
"""
|
||||||
def __init__(self, bot: CrocBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.data = bot.data.load_registry(ShoutoutData())
|
self.crocbot = bot.crocbot
|
||||||
|
self.data = bot.db.load_registry(ShoutoutData())
|
||||||
|
|
||||||
self.loaded = asyncio.Event()
|
self.loaded = asyncio.Event()
|
||||||
|
|
||||||
async def cog_load(self):
|
async def cog_load(self):
|
||||||
await self.data.init()
|
await self.data.init()
|
||||||
|
self._load_twitch_methods(self.crocbot)
|
||||||
self.loaded.set()
|
self.loaded.set()
|
||||||
|
|
||||||
async def ensure_loaded(self):
|
async def cog_unload(self):
|
||||||
if not self.loaded.is_set():
|
self.loaded.clear()
|
||||||
await self.cog_load()
|
self._unload_twitch_methods(self.crocbot)
|
||||||
|
|
||||||
@commands.Cog.event('event_ready') # type: ignore
|
|
||||||
async def on_ready(self):
|
|
||||||
await self.ensure_loaded()
|
|
||||||
|
|
||||||
async def cog_check(self, ctx):
|
async def cog_check(self, ctx):
|
||||||
await self.ensure_loaded()
|
if not self.loaded.is_set():
|
||||||
|
await ctx.reply("Tasklists are still loading! Please wait a moment~")
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def format_shoutout(self, text: str, user: twitchio.User):
|
async def format_shoutout(self, text: str, user: twitchio.User):
|
||||||
channels = await self.bot.fetch_channels([user.id])
|
channels = await self.crocbot.fetch_channels([user.id])
|
||||||
if channels:
|
if channels:
|
||||||
channel = channels[0]
|
channel = channels[0]
|
||||||
game = channel.game_name or 'Unknown'
|
game = channel.game_name or 'Unknown'
|
||||||
else:
|
else:
|
||||||
game = 'Unknown'
|
game = 'Unknown'
|
||||||
|
|
||||||
streams = await self.bot.fetch_streams([user.id])
|
streams = await self.crocbot.fetch_streams([user.id])
|
||||||
live = bool(streams)
|
live = bool(streams)
|
||||||
|
|
||||||
mapping = {
|
mapping = {
|
||||||
|
|||||||
@@ -4,5 +4,5 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
from .cog import TagCog
|
from .cog import TagCog
|
||||||
|
|
||||||
def prepare(bot):
|
async def setup(bot):
|
||||||
bot.add_cog(TagCog(bot))
|
await bot.add_cog(TagCog(bot))
|
||||||
|
|||||||
@@ -6,16 +6,17 @@ import difflib
|
|||||||
import twitchio
|
import twitchio
|
||||||
from twitchio.ext import commands
|
from twitchio.ext import commands
|
||||||
|
|
||||||
from meta import CrocBot
|
from meta import CrocBot, LionBot, LionCog
|
||||||
from utils.lib import utc_now
|
from utils.lib import utc_now
|
||||||
from . import logger
|
from . import logger
|
||||||
from .data import TagData
|
from .data import TagData
|
||||||
|
|
||||||
|
|
||||||
class TagCog(commands.Cog):
|
class TagCog(LionCog):
|
||||||
def __init__(self, bot: CrocBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.data = bot.data.load_registry(TagData())
|
self.crocbot = bot.crocbot
|
||||||
|
self.data = bot.db.load_registry(TagData())
|
||||||
|
|
||||||
self.loaded = asyncio.Event()
|
self.loaded = asyncio.Event()
|
||||||
|
|
||||||
@@ -31,19 +32,24 @@ class TagCog(commands.Cog):
|
|||||||
|
|
||||||
self.tags.clear()
|
self.tags.clear()
|
||||||
self.tags.update(tags)
|
self.tags.update(tags)
|
||||||
|
logger.info(f"Loaded {len(tags)} into cache.")
|
||||||
|
|
||||||
async def cog_load(self):
|
async def cog_load(self):
|
||||||
await self.data.init()
|
await self.data.init()
|
||||||
await self.load_tags()
|
await self.load_tags()
|
||||||
|
self._load_twitch_methods(self.crocbot)
|
||||||
self.loaded.set()
|
self.loaded.set()
|
||||||
|
|
||||||
async def ensure_loaded(self):
|
async def cog_unload(self):
|
||||||
if not self.loaded.is_set():
|
self.loaded.clear()
|
||||||
await self.cog_load()
|
self.tags.clear()
|
||||||
|
self._unload_twitch_methods(self.crocbot)
|
||||||
|
|
||||||
@commands.Cog.event('event_ready')
|
async def cog_check(self, ctx):
|
||||||
async def on_ready(self):
|
if not self.loaded.is_set():
|
||||||
await self.ensure_loaded()
|
await ctx.reply("Tasklists are still loading! Please wait a moment~")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
# API
|
# API
|
||||||
|
|
||||||
|
|||||||
9
src/twitch/__init__.py
Normal file
9
src/twitch/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from .cog import TwitchAuthCog
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
await bot.add_cog(TwitchAuthCog(bot))
|
||||||
|
|
||||||
50
src/twitch/authclient.py
Normal file
50
src/twitch/authclient.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""
|
||||||
|
Testing client for the twitch AuthServer.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from twitchAPI.twitch import Twitch
|
||||||
|
from twitchAPI.oauth import UserAuthenticator
|
||||||
|
from twitchAPI.type import AuthScope
|
||||||
|
|
||||||
|
from meta.config import conf
|
||||||
|
|
||||||
|
|
||||||
|
URI = "http://localhost:3000/twiauth/confirm"
|
||||||
|
TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ]
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Load in client id and secret
|
||||||
|
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
|
||||||
|
auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI)
|
||||||
|
url = auth.return_auth_url()
|
||||||
|
|
||||||
|
# Post url to user
|
||||||
|
print(url)
|
||||||
|
|
||||||
|
# Send listen request to server
|
||||||
|
# Wait for listen request
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws:
|
||||||
|
await ws.send_json({'state': auth.state})
|
||||||
|
result = await ws.receive_json()
|
||||||
|
|
||||||
|
# Hopefully get back code, print the response
|
||||||
|
print(f"Recieved: {result}")
|
||||||
|
|
||||||
|
# Authorise with code and client details
|
||||||
|
tokens = await auth.authenticate(user_token=result['code'])
|
||||||
|
if tokens:
|
||||||
|
token, refresh = tokens
|
||||||
|
await twitch.set_user_authentication(token, TARGET_SCOPE, refresh)
|
||||||
|
print(f"Authorised!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(main())
|
||||||
86
src/twitch/authserver.py
Normal file
86
src/twitch/authserver.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
reqid: ContextVar[str] = ContextVar('reqid', default='ROOT')
|
||||||
|
|
||||||
|
|
||||||
|
class AuthServer:
|
||||||
|
def __init__(self):
|
||||||
|
self.listeners = {}
|
||||||
|
|
||||||
|
async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
args = request.query
|
||||||
|
if 'state' not in args:
|
||||||
|
raise web.HTTPBadRequest(text="No state provided.")
|
||||||
|
if args['state'] not in self.listeners:
|
||||||
|
raise web.HTTPBadRequest(text="Invalid state.")
|
||||||
|
self.listeners[args['state']].set_result(dict(args))
|
||||||
|
return web.Response(text="Authorisation complete! You may now close this page and return to the application.")
|
||||||
|
|
||||||
|
async def handle_listen_request(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
_reqid = str(uuid.uuid1())
|
||||||
|
reqid.set(_reqid)
|
||||||
|
|
||||||
|
logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}")
|
||||||
|
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
# Get the listen request data
|
||||||
|
try:
|
||||||
|
listen_req = await ws.receive_json(timeout=60)
|
||||||
|
logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}")
|
||||||
|
if 'state' not in listen_req:
|
||||||
|
logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.")
|
||||||
|
raise web.HTTPBadRequest(text="Listen request must include state string.")
|
||||||
|
elif listen_req['state'] in self.listeners:
|
||||||
|
logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.")
|
||||||
|
raise web.HTTPBadRequest(text="Invalid state string.")
|
||||||
|
except ValueError:
|
||||||
|
logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.")
|
||||||
|
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
|
||||||
|
except TypeError:
|
||||||
|
logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.")
|
||||||
|
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.")
|
||||||
|
raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.")
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f"[reqid: {_reqid}] Unknown exception.")
|
||||||
|
raise web.HTTPInternalServerError()
|
||||||
|
|
||||||
|
try:
|
||||||
|
fut = self.listeners[listen_req['state']] = asyncio.Future()
|
||||||
|
result = await asyncio.wait_for(fut, timeout=120)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.")
|
||||||
|
raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.")
|
||||||
|
finally:
|
||||||
|
self.listeners.pop(listen_req['state'], None)
|
||||||
|
|
||||||
|
logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.")
|
||||||
|
await ws.send_json(result)
|
||||||
|
await ws.close()
|
||||||
|
logger.debug(f"[reqid: {_reqid}] Request completed handling.")
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
app = web.Application()
|
||||||
|
server = AuthServer()
|
||||||
|
app.router.add_get("/twiauth/confirm", server.handle_twitch_callback)
|
||||||
|
app.router.add_get("/twiauth/listen", server.handle_listen_request)
|
||||||
|
|
||||||
|
logger.info("App setup and configured. Starting now.")
|
||||||
|
web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import sys
|
||||||
|
main(sys.argv)
|
||||||
84
src/twitch/cog.py
Normal file
84
src/twitch/cog.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import asyncio
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands as cmds
|
||||||
|
|
||||||
|
from twitchAPI.oauth import UserAuthenticator
|
||||||
|
from twitchAPI.twitch import AuthType, Twitch
|
||||||
|
from twitchAPI.type import AuthScope
|
||||||
|
import twitchio
|
||||||
|
from twitchio.ext import commands
|
||||||
|
|
||||||
|
|
||||||
|
from data.queries import ORDER
|
||||||
|
from meta import LionCog, LionBot, CrocBot
|
||||||
|
from meta.LionContext import LionContext
|
||||||
|
from twitch.userflow import UserAuthFlow
|
||||||
|
from utils.lib import utc_now
|
||||||
|
from . import logger
|
||||||
|
from .data import TwitchAuthData
|
||||||
|
|
||||||
|
|
||||||
|
class TwitchAuthCog(LionCog):
|
||||||
|
DEFAULT_SCOPES = []
|
||||||
|
|
||||||
|
def __init__(self, bot: LionBot):
|
||||||
|
self.bot = bot
|
||||||
|
self.data = bot.db.load_registry(TwitchAuthData())
|
||||||
|
|
||||||
|
async def cog_load(self):
|
||||||
|
await self.data.init()
|
||||||
|
|
||||||
|
# ----- Auth API -----
|
||||||
|
|
||||||
|
async def fetch_client_for(self, userid: int):
|
||||||
|
...
|
||||||
|
|
||||||
|
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
|
||||||
|
"""
|
||||||
|
Checks whether the given userid is authorised.
|
||||||
|
If 'scopes' is given, will also check the user has all of the given scopes.
|
||||||
|
"""
|
||||||
|
authrow = await self.data.UserAuthRow.fetch(userid)
|
||||||
|
if authrow:
|
||||||
|
if scopes:
|
||||||
|
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||||
|
has_auth = set(map(str, scopes)).issubset(has_scopes)
|
||||||
|
else:
|
||||||
|
has_auth = True
|
||||||
|
else:
|
||||||
|
has_auth = False
|
||||||
|
return has_auth
|
||||||
|
|
||||||
|
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
|
||||||
|
"""
|
||||||
|
Start the user authentication flow for the given userid.
|
||||||
|
Will request the given scopes along with the default ones and any existing scopes.
|
||||||
|
"""
|
||||||
|
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||||
|
existing = map(AuthScope, existing_strs)
|
||||||
|
to_request = set(existing).union(scopes)
|
||||||
|
return await self.start_auth(to_request)
|
||||||
|
|
||||||
|
async def start_auth(self, scopes = []):
|
||||||
|
# TODO: Work out a way to just clone the current twitch object
|
||||||
|
# Or can we otherwise build UserAuthenticator without app auth?
|
||||||
|
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
|
||||||
|
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
|
||||||
|
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
|
||||||
|
await flow.setup()
|
||||||
|
|
||||||
|
return flow
|
||||||
|
|
||||||
|
# ----- Commands -----
|
||||||
|
@cmds.hybrid_command(name='auth')
|
||||||
|
async def cmd_auth(self, ctx: LionContext):
|
||||||
|
if ctx.interaction:
|
||||||
|
await ctx.interaction.response.defer(ephemeral=True)
|
||||||
|
flow = await self.start_auth()
|
||||||
|
await ctx.reply(flow.auth.return_auth_url())
|
||||||
|
await flow.run()
|
||||||
|
await ctx.reply("Authentication Complete!")
|
||||||
79
src/twitch/data.py
Normal file
79
src/twitch/data.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import datetime as dt
|
||||||
|
|
||||||
|
from data import Registry, RowModel, Table
|
||||||
|
from data.columns import Integer, String, Timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class TwitchAuthData(Registry):
|
||||||
|
class UserAuthRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE twitch_user_auth(
|
||||||
|
userid TEXT PRIMARY KEY,
|
||||||
|
access_token TEXT NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
refresh_token TEXT NOT NULL,
|
||||||
|
obtained_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'twitch_user_auth'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
userid = Integer(primary=True)
|
||||||
|
access_token = String()
|
||||||
|
refresh_token = String()
|
||||||
|
expires_at = Timestamp()
|
||||||
|
obtained_at = Timestamp()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def update_user_auth(
|
||||||
|
cls, userid: str, token: str, refresh: str,
|
||||||
|
expires_at: dt.datetime, obtained_at: dt.datetime,
|
||||||
|
scopes: list[str]
|
||||||
|
):
|
||||||
|
if cls._connector is None:
|
||||||
|
raise ValueError("Attempting to use uninitialised Registry.")
|
||||||
|
async with cls._connector.connection() as conn:
|
||||||
|
cls._connector.conn = conn
|
||||||
|
async with conn.transaction():
|
||||||
|
# Clear row for this userid
|
||||||
|
await cls.table.delete_where(userid=userid)
|
||||||
|
|
||||||
|
# Insert new user row
|
||||||
|
row = await cls.create(
|
||||||
|
userid=userid,
|
||||||
|
access_token=token,
|
||||||
|
refresh_token=refresh,
|
||||||
|
expires_at=expires_at,
|
||||||
|
obtained_at=obtained_at
|
||||||
|
)
|
||||||
|
# Insert new scope rows
|
||||||
|
if scopes:
|
||||||
|
await TwitchAuthData.user_scopes.insert_many(
|
||||||
|
('userid', 'scope'),
|
||||||
|
*((userid, scope) for scope in scopes)
|
||||||
|
)
|
||||||
|
return row
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_scopes_for(cls, userid: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Get a list of scopes stored for the given user.
|
||||||
|
Will return an empty list if the user is not authenticated.
|
||||||
|
"""
|
||||||
|
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
|
||||||
|
|
||||||
|
return [row.scope for row in rows] if rows else []
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE twitch_user_scopes(
|
||||||
|
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
scope TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||||
|
"""
|
||||||
|
user_scopes = Table('twitch_token_scopes')
|
||||||
0
src/twitch/lib.py
Normal file
0
src/twitch/lib.py
Normal file
88
src/twitch/userflow.py
Normal file
88
src/twitch/userflow.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import datetime as dt
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from twitchAPI.twitch import Twitch
|
||||||
|
from twitchAPI.oauth import UserAuthenticator, validate_token
|
||||||
|
from twitchAPI.type import AuthType
|
||||||
|
from twitchio.client import asyncio
|
||||||
|
|
||||||
|
from meta.errors import SafeCancellation
|
||||||
|
from utils.lib import utc_now
|
||||||
|
from .data import TwitchAuthData
|
||||||
|
from . import logger
|
||||||
|
|
||||||
|
class UserAuthFlow:
|
||||||
|
auth: UserAuthenticator
|
||||||
|
data: TwitchAuthData
|
||||||
|
auth_ws: str
|
||||||
|
|
||||||
|
def __init__(self, data, auth, auth_ws):
|
||||||
|
self.auth = auth
|
||||||
|
self.data = data
|
||||||
|
self.auth_ws = auth_ws
|
||||||
|
|
||||||
|
self._setup_done = asyncio.Event()
|
||||||
|
self._comm_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
"""
|
||||||
|
Establishes websocket connection to the AuthServer,
|
||||||
|
and requests listening for the given state.
|
||||||
|
Propagates any exceptions that occur during connection setup.
|
||||||
|
"""
|
||||||
|
if self._setup_done.is_set():
|
||||||
|
raise ValueError("UserAuthFlow is already set up.")
|
||||||
|
self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate')
|
||||||
|
await self._setup_done.wait()
|
||||||
|
if self._comm_task.done() and (exc := self._comm_task.exception()):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
async def _communicate(self):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.ws_connect(self.auth_ws) as ws:
|
||||||
|
await ws.send_json({'state': self.auth.state})
|
||||||
|
self._setup_done.set()
|
||||||
|
return await ws.receive_json()
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
if not self._setup_done.is_set():
|
||||||
|
raise ValueError("Cannot run UserAuthFlow before setup.")
|
||||||
|
if self._comm_task is None:
|
||||||
|
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
|
||||||
|
|
||||||
|
result = await self._comm_task
|
||||||
|
if result.get('error', None):
|
||||||
|
# TODO Custom auth errors
|
||||||
|
# This is only documented to occure when the user denies the auth
|
||||||
|
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
|
||||||
|
|
||||||
|
if result.get('state', None) != self.auth.state:
|
||||||
|
# This should never happen unless the authserver has its wires crossed somehow,
|
||||||
|
# or the connection has been tampered with.
|
||||||
|
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
|
||||||
|
logger.critical(
|
||||||
|
f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG."
|
||||||
|
)
|
||||||
|
raise SafeCancellation(
|
||||||
|
"Could not complete authentication! Invalid server response."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now assume result has a valid code
|
||||||
|
# Exchange code for an auth token and a refresh token
|
||||||
|
# Ignore type here, authenticate returns None if a callback function has been given.
|
||||||
|
token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore
|
||||||
|
|
||||||
|
# Fetch the associated userid and basic info
|
||||||
|
v_result = await validate_token(token)
|
||||||
|
userid = v_result['user_id']
|
||||||
|
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
|
||||||
|
|
||||||
|
# Save auth data
|
||||||
|
return await self.data.UserAuthRow.update_user_auth(
|
||||||
|
userid=userid, token=token, refresh=refresh,
|
||||||
|
expires_at=expiry, obtained_at=utc_now(),
|
||||||
|
scopes=[scope.value for scope in self.auth.scopes]
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user