Compare commits

...

2 Commits

Author SHA1 Message Date
8421c5359d (WIP) Add user profile module. 2025-06-06 00:05:41 +10:00
2cf81c38e8 Add twitch auth module. 2025-06-06 00:05:24 +10:00
15 changed files with 1313 additions and 4 deletions

View File

@@ -4,3 +4,5 @@ discord.py [voice]
iso8601
psycopg[pool]
pytz
twitchio
twitchAPI

View File

@@ -4,6 +4,7 @@ import logging
import aiohttp
import discord
from discord.ext import commands
from twitchAPI.twitch import Twitch
from meta import LionBot, conf, sharding, appname
from meta.app import shardname
@@ -49,13 +50,15 @@ async def _data_monitor() -> ComponentStatus:
async def main():
log_action_stack.set(("Initialising",))
logger.info("Initialising StudyLion")
logger.info("Initialising LionBot")
intents = discord.Intents.all()
intents.members = True
intents.message_content = True
intents.presences = False
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
async with db.open():
version = await db.version()
if version.version != DATA_VERSION:
@@ -82,6 +85,7 @@ async def main():
help_command=None,
proxy=conf.bot.get('proxy', None),
chunk_guilds_at_startup=False,
twitch=twitch
) as lionbot:
ctx_bot.set(lionbot)
lionbot.system_monitor.add_component(
@@ -89,11 +93,11 @@ async def main():
)
try:
log_context.set(f"APP: {appname}")
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
logger.info("LionBot initialised, starting!", extra={'action': 'Starting'})
await lionbot.start(conf.bot['TOKEN'])
except asyncio.CancelledError:
log_context.set(f"APP: {appname}")
logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
logger.info("LionBot closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
def _main():

View File

@@ -9,6 +9,7 @@ from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
from aiohttp import ClientSession
from twitchAPI.twitch import Twitch
from data import Database
from utils.lib import tabulate
@@ -23,6 +24,8 @@ from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStat
if TYPE_CHECKING:
from core.cog import CoreCog
from twitch.cog import TwitchAuthCog
from modules.profiles.cog import ProfileCog
logger = logging.getLogger(__name__)
@@ -31,7 +34,9 @@ class LionBot(Bot):
def __init__(
self, *args, appname: str, shardname: str, db: Database, config: Conf,
initial_extensions: List[str], web_client: ClientSession,
testing_guilds: List[int] = [], **kwargs
twitch: Twitch,
testing_guilds: List[int] = [],
**kwargs
):
kwargs.setdefault('tree_cls', LionTree)
super().__init__(*args, **kwargs)
@@ -43,6 +48,7 @@ class LionBot(Bot):
self.shardname = shardname
# self.appdata = appdata
self.config = config
self.twitch = twitch
self.system_monitor = SystemMonitor()
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
@@ -101,6 +107,14 @@ class LionBot(Bot):
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
...
@overload
def get_cog(self, name: Literal['ProfileCog']) -> 'ProfileCog':
...
@overload
def get_cog(self, name: Literal['TwitchAuthCog']) -> 'TwitchAuthCog':
...
@overload
def get_cog(self, name: str) -> Optional[Cog]:
...

View File

@@ -0,0 +1,8 @@
import logging
logger = logging.getLogger(__name__)
from .cog import ProfileCog
async def setup(bot):
await bot.add_cog(ProfileCog(bot))

436
src/modules/profiles/cog.py Normal file
View File

@@ -0,0 +1,436 @@
import asyncio
from enum import Enum
from typing import Optional, overload
from datetime import timedelta
import discord
from discord import app_commands as appcmds
from discord.ext import commands as cmds
from twitchAPI.helper import first
from twitchAPI.type import AuthScope
import twitchio
from twitchio.ext import commands
from twitchio import User
from twitchAPI.object.api import TwitchUser
from data.queries import ORDER
from meta import LionCog, LionBot, LionContext
from meta.logger import log_wrap
from utils.lib import utc_now
from . import logger
from .data import ProfileData
from .profile import UserProfile
from .community import Community
class ProfileCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(ProfileData())
self._profile_migrators = {}
self._comm_migrators = {}
async def cog_load(self):
await self.data.init()
async def cog_check(self, ctx):
return True
# Profile API
def add_profile_migrator(self, migrator, name=None):
name = name or migrator.__name__
self._profile_migrators[name or migrator.__name__] = migrator
logger.info(
f"Added user profile migrator {name}: {migrator}"
)
return migrator
def del_profile_migrator(self, name: str):
migrator = self._profile_migrators.pop(name, None)
logger.info(
f"Removed user profile migrator {name}: {migrator}"
)
@log_wrap(action="profile migration")
async def migrate_profile(self, source_profile, target_profile) -> list[str]:
logger.info(
f"Beginning user profile migration from {source_profile!r} to {target_profile!r}"
)
results = []
# Wrap this in a transaction so if something goes wrong with migration,
# we roll back safely (although this may mess up caches)
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
for name, migrator in self._profile_migrators.items():
try:
result = await migrator(source_profile, target_profile)
if result:
results.append(result)
except Exception:
logger.exception(
f"Unexpected exception running user profile migrator {name} "
f"migrating {source_profile!r} to {target_profile!r}."
)
raise
# Move all Discord and Twitch profile references over to the new profile
discord_rows = await self.data.DiscordProfileRow.table.update_where(
profileid=source_profile.profileid
).set(profileid=target_profile.profileid)
results.append(f"Migrated {len(discord_rows)} attached discord account(s).")
twitch_rows = await self.data.TwitchProfileRow.table.update_where(
profileid=source_profile.profileid
).set(profileid=target_profile.profileid)
results.append(f"Migrated {len(twitch_rows)} attached twitch account(s).")
# And then mark the old profile as migrated
await source_profile.profile_row.update(migrated=target_profile.profileid)
results.append("Marking old profile as migrated.. finished!")
return results
async def fetch_profile_by_id(self, profile_id: int) -> UserProfile:
"""
Fetch a UserProfile by the given id.
"""
return await UserProfile.fetch(self.bot, profile_id=profile_id)
async def fetch_profile_discord(self, user: discord.Member | discord.User) -> UserProfile:
"""
Fetch or create a UserProfile from the provided discord account.
"""
profile = await UserProfile.fetch_from_discordid(self.bot, user.id)
if profile is None:
profile = await UserProfile.create_from_discord(self.bot, user)
return profile
async def fetch_profile_twitch(self, user: twitchio.User) -> UserProfile:
"""
Fetch or create a UserProfile from the provided twitch account.
"""
profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
if profile is None:
profile = await UserProfile.create_from_twitch(self.bot, user)
return profile
# Community API
def add_community_migrator(self, migrator, name=None):
name = name or migrator.__name__
self._comm_migrators[name or migrator.__name__] = migrator
logger.info(
f"Added community migrator {name}: {migrator}"
)
return migrator
def del_community_migrator(self, name: str):
migrator = self._comm_migrators.pop(name, None)
logger.info(
f"Removed community migrator {name}: {migrator}"
)
@log_wrap(action="community migration")
async def migrate_community(self, source_comm, target_comm) -> list[str]:
logger.info(
f"Beginning community migration from {source_comm!r} to {target_comm!r}"
)
results = []
# Wrap this in a transaction so if something goes wrong with migration,
# we roll back safely (although this may mess up caches)
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
for name, migrator in self._comm_migrators.items():
try:
result = await migrator(source_comm, target_comm)
if result:
results.append(result)
except Exception:
logger.exception(
f"Unexpected exception running community migrator {name} "
f"migrating {source_comm!r} to {target_comm!r}."
)
raise
# Move all Discord and Twitch community preferences over to the new profile
discord_rows = await self.data.DiscordCommunityRow.table.update_where(
profileid=source_comm.communityid
).set(communityid=target_comm.communityid)
results.append(f"Migrated {len(discord_rows)} attached discord guilds.")
twitch_rows = await self.data.TwitchCommunityRow.table.update_where(
communityid=source_comm.communityid
).set(communityid=target_comm.communityid)
results.append(f"Migrated {len(twitch_rows)} attached twitch channel(s).")
# And then mark the old community as migrated
await source_comm.update(migrated=target_comm.communityid)
results.append("Marking old community as migrated.. finished!")
return results
async def fetch_community_by_id(self, community_id: int) -> Community:
"""
Fetch a Community by the given id.
"""
return await Community.fetch(self.bot, community_id=community_id)
async def fetch_community_discord(self, guild: discord.Guild) -> Community:
"""
Fetch or create a Community from the provided discord guild.
"""
comm = await Community.fetch_from_discordid(self.bot, guild.id)
if comm is None:
comm = await Community.create_from_discord(self.bot, guild)
return comm
async def fetch_community_twitch(self, user: twitchio.User) -> Community:
"""
Fetch or create a Community from the provided twitch account.
"""
community = await Community.fetch_from_twitchid(self.bot, user.id)
if community is None:
community = await Community.create_from_twitch(self.bot, user)
return community
# ----- Profile Commands -----
@cmds.hybrid_group(
name='profiles',
description="Base comand group for user profiles."
)
async def profiles_grp(self, ctx: LionContext):
...
@profiles_grp.group(
name='link',
description="Base command group for linking profiles"
)
async def profiles_link_grp(self, ctx: LionContext):
...
@profiles_link_grp.command(
name='twitch',
description="Link a twitch account to your current profile."
)
async def profiles_link_twitch_cmd(self, ctx: LionContext):
if not ctx.interaction:
return
await ctx.interaction.response.defer(ephemeral=True)
# Ask the user to go through auth to get their userid
auth_cog = self.bot.get_cog('TwitchAuthCog')
flow = await auth_cog.start_auth()
message = await ctx.reply(
f"Please [click here]({flow.auth.return_auth_url()}) to link your profile "
"to Twitch."
)
authrow = await flow.run()
await message.edit(
content="Authentication Complete! Beginning profile merge..."
)
# results = await self.crocbot.fetch_users(ids=[authrow.userid])
# if not results:
# logger.error(
# f"User {authrow} obtained from Twitch authentication does not exist."
# )
# await ctx.error_reply("Sorry, something went wrong. Please try again later!")
# return
# user = results[0]
try:
user = await first(self.bot.twitch.get_users(user_ids=[str(authrow.userid)]))
except Exception:
logger.error(
f"Looking up user {authrow} from Twitch authentication flow raised an error.",
exc_info=True
)
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
return
if user is None:
logger.error(
f"User {authrow} obtained from Twitch authentication does not exist."
)
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
return
# Retrieve author's profile if it exists
author_profile = await UserProfile.fetch_from_discordid(self.bot, ctx.author.id)
# Check if the twitch-side user has a profile
source_profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
if author_profile and source_profile is None:
# All we need to do is attach the twitch row
await author_profile.attach_twitch(user.id)
await message.edit(
content=f"Successfully added Twitch account **{user.display_name}**! There was no profile data to merge."
)
elif source_profile and author_profile is None:
# Attach the discord row to the profile
await source_profile.attach_discord(ctx.author.id)
await message.edit(
content=f"Successfully connected to Twitch profile **{user.display_name}**! There was no profile data to merge."
)
elif source_profile is None and author_profile is None:
profile = await UserProfile.create_from_discord(self.bot, ctx.author)
await profile.attach_twitch(user.id)
await message.edit(
content=f"Opened a new user profile for you and linked Twitch account **{user.display_name}**."
)
elif author_profile.profileid == source_profile.profileid:
await message.edit(
content=f"The Twitch account **{user.display_name}** is already linked to your profile!"
)
else:
# Migrate the existing profile data to the new profiles
try:
results = await self.migrate_profile(source_profile, author_profile)
except Exception:
await ctx.error_reply(
"An issue was encountered while merging your account profiles!\n"
"Migration rolled back, no data has been lost.\n"
"The developer has been notified. Please try again later!"
)
raise
content = '\n'.join((
"## Connecting Twitch account and merging profiles...",
*results,
"**Successfully linked account and merged profile data!**"
))
await message.edit(content=content)
# ----- Community Commands -----
@cmds.hybrid_group(
name='community',
description="Base comand group for community profiles."
)
async def community_grp(self, ctx: LionContext):
...
@community_grp.group(
name='link',
description="Base command group for linking communities"
)
async def community_link_grp(self, ctx: LionContext):
...
@community_link_grp.command(
name='twitch',
description="Link a twitch account to this community."
)
@appcmds.guild_only()
@appcmds.default_permissions(manage_guild=True)
async def comm_link_twitch_cmd(self, ctx: LionContext):
if not ctx.interaction:
return
assert ctx.guild is not None
await ctx.interaction.response.defer(ephemeral=True)
if not ctx.author.guild_permissions.manage_guild:
await ctx.error_reply("You need the `MANAGE_GUILD` permission to link this guild to a community.")
return
# Ask the user to go through auth to get their userid
auth_cog = self.bot.get_cog('TwitchAuthCog')
flow = await auth_cog.start_auth(
scopes=[
AuthScope.CHAT_EDIT,
AuthScope.CHAT_READ,
AuthScope.MODERATION_READ,
AuthScope.CHANNEL_BOT,
]
)
message = await ctx.reply(
f"Please [click here]({flow.auth.return_auth_url()}) to link your Twitch channel to this server."
)
authrow = await flow.run()
await message.edit(
content="Authentication Complete! Beginning community profile merge..."
)
# results = await self.crocbot.fetch_users(ids=[authrow.userid])
# if not results:
# logger.error(
# f"User {authrow} obtained from Twitch authentication does not exist."
# )
# await ctx.error_reply("Sorry, something went wrong. Please try again later!")
# return
# user = results[0]
try:
user = await first(self.bot.twitch.get_users(user_ids=[str(authrow.userid)]))
except Exception:
logger.error(
f"Looking up user {authrow} from Twitch authentication flow raised an error.",
exc_info=True
)
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
return
if user is None:
logger.error(
f"User {authrow} obtained from Twitch authentication does not exist."
)
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
return
# Retrieve author's profile if it exists
guild_comm = await Community.fetch_from_discordid(self.bot, ctx.guild.id)
# Check if the twitch-side user has a profile
twitch_comm = await Community.fetch_from_twitchid(self.bot, user.id)
if guild_comm and twitch_comm is None:
# All we need to do is attach the twitch row
await guild_comm.attach_twitch(user.id)
await message.edit(
content=f"Successfully linked Twitch channel **{user.display_name}**! There was no community data to merge."
)
elif twitch_comm and guild_comm is None:
# Attach the discord row to the profile
await twitch_comm.attach_discord(ctx.guild.id)
await message.edit(
content=f"Successfully connected to Twitch channel **{user.display_name}**!"
)
elif twitch_comm is None and guild_comm is None:
profile = await Community.create_from_discord(self.bot, ctx.guild)
await profile.attach_twitch(user.id)
await message.edit(
content=f"Created a new community for this server and linked Twitch account **{user.display_name}**."
)
elif guild_comm.communityid == twitch_comm.communityid:
await message.edit(
content=f"This server is already linked to the Twitch channel **{user.display_name}**!"
)
else:
# Migrate the existing profile data to the new profiles
try:
results = await self.migrate_community(twitch_comm, guild_comm)
except Exception:
await ctx.error_reply(
"An issue was encountered while merging your community profiles!\n"
"Migration rolled back, no data has been lost.\n"
"The developer has been notified. Please try again later!"
)
raise
content = '\n'.join((
"## Connecting Twitch account and merging community profiles...",
*results,
"**Successfully linked account and merged community data!**"
))
await message.edit(content=content)

View File

@@ -0,0 +1,123 @@
from typing import Optional, Self
import discord
from meta import LionBot
from utils.lib import utc_now
from . import logger
from .data import ProfileData
class Community:
def __init__(self, bot: LionBot, community_row):
self.bot = bot
self.row: ProfileData.CommunityRow = community_row
@property
def cog(self):
return self.bot.get_cog('ProfileCog')
@property
def data(self) -> ProfileData:
return self.cog.data
@property
def communityid(self):
return self.row.communityid
def __repr__(self):
return f"<Community communityid={self.communityid} row={self.row}>"
async def attach_discord(self, guildid: int):
"""
Attach a new discord guild to this community.
Assumes the discord guild is not already associated to a community.
"""
discord_row = await self.data.DiscordCommunityRow.create(
communityid=self.communityid,
guildid=guildid
)
logger.info(
f"Attached discord guild {guildid} to community {self!r}"
)
return discord_row
async def attach_twitch(self, channelid: str):
"""
Attach a new Twitch user channel to this community.
"""
twitch_row = await self.data.TwitchCommunityRow.create(
communityid=self.communityid,
channelid=str(channelid)
)
logger.info(
f"Attached twitch channel {channelid} to community {self!r}"
)
return twitch_row
async def discord_guilds(self) -> list[ProfileData.DiscordCommunityRow]:
"""
Fetch the Discord guild rows associated to this community.
"""
return await self.data.DiscordCommunityRow.fetch_where(communityid=self.communityid)
async def twitch_channels(self) -> list[ProfileData.TwitchCommunityRow]:
"""
Fetch the Twitch user rows associated to this profile.
"""
return await self.data.TwitchCommunityRow.fetch_where(communityid=self.communityid)
@classmethod
async def fetch(cls, bot: LionBot, community_id: int) -> Self:
community_row = await bot.get_cog('ProfileCog').data.CommunityRow.fetch(community_id)
if community_row is None:
raise ValueError("Provided community_id does not exist.")
return cls(bot, community_row)
@classmethod
async def fetch_from_twitchid(cls, bot: LionBot, channelid: int | str) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.TwitchCommunityRow.fetch_where(channelid=str(channelid))
if rows:
return await cls.fetch(bot, rows[0].communityid)
@classmethod
async def fetch_from_discordid(cls, bot: LionBot, guildid: int) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.DiscordCommunityRow.fetch_where(guildid=guildid)
if rows:
return await cls.fetch(bot, rows[0].communityid)
@classmethod
async def create(cls, bot: LionBot, **kwargs) -> Self:
"""
Create a new empty community with the given initial arguments.
Communities should usually be created using `create_from_discord` or `create_from_twitch`
to correctly setup initial preferences (e.g. name, avatar).
"""
# Create a new community
data = bot.get_cog('ProfileCog').data
row = await data.CommunityRow.create(created_at=utc_now(), **kwargs)
return await cls.fetch(bot, row.communityid)
@classmethod
async def create_from_discord(cls, bot: LionBot, guild: discord.Guild, **kwargs) -> Self:
"""
Create a new community using the given Discord guild as a base.
"""
self = await cls.create(bot, **kwargs)
await self.attach_discord(guild.id)
return self
@classmethod
async def create_from_twitch(cls, bot: LionBot, user, **kwargs) -> Self:
"""
Create a new profile using the given Twitch channel user as a base.
The provided `user` must have an `id` attribute.
"""
self = await cls.create(bot, **kwargs)
await self.attach_twitch(str(user.id))
return self

View File

@@ -0,0 +1,158 @@
from data import Registry, RowModel
from data.columns import Integer, String, Timestamp
class ProfileData(Registry):
class UserProfileRow(RowModel):
"""
Schema
------
CREATE TABLE user_profiles(
profileid SERIAL PRIMARY KEY,
nickname TEXT,
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'user_profiles'
_cache_ = {}
profileid = Integer(primary=True)
nickname = String()
migrated = Integer()
created_at = Timestamp()
class DiscordProfileRow(RowModel):
"""
Schema
------
CREATE TABLE profiles_discord(
linkid SERIAL PRIMARY KEY,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
userid BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid);
CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid);
"""
_tablename_ = 'profiles_discord'
_cache_ = {}
linkid = Integer(primary=True)
profileid = Integer()
userid = Integer()
created_at = Integer()
@classmethod
async def fetch_profile(cls, profileid: int):
rows = await cls.fetch_where(profiled=profileid)
return next(rows, None)
class TwitchProfileRow(RowModel):
"""
Schema
------
CREATE TABLE profiles_twitch(
linkid SERIAL PRIMARY KEY,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
userid TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid);
CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid);
"""
_tablename_ = 'profiles_twitch'
_cache_ = {}
linkid = Integer(primary=True)
profileid = Integer()
userid = String()
created_at = Timestamp()
@classmethod
async def fetch_profile(cls, profileid: int):
rows = await cls.fetch_where(profiled=profileid)
return next(rows, None)
class CommunityRow(RowModel):
"""
Schema
------
CREATE TABLE communities(
communityid SERIAL PRIMARY KEY,
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities'
_cache_ = {}
communityid = Integer(primary=True)
migrated = Integer()
created_at = Timestamp()
class DiscordCommunityRow(RowModel):
"""
Schema
------
CREATE TABLE communities_discord(
guildid BIGINT PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities_discord'
_cache_ = {}
guildid = Integer(primary=True)
communityid = Integer()
linked_at = Timestamp()
@classmethod
async def fetch_community(cls, communityid: int):
rows = await cls.fetch_where(communityd=communityid)
return next(rows, None)
class TwitchCommunityRow(RowModel):
"""
Schema
------
CREATE TABLE communities_twitch(
channelid TEXT PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities_twitch'
_cache_ = {}
channelid = String(primary=True)
communityid = Integer()
linked_at = Timestamp()
@classmethod
async def fetch_community(cls, communityid: int):
rows = await cls.fetch_where(communityd=communityid)
return next(rows, None)
class CommunityMemberRow(RowModel):
"""
Schema
------
CREATE TABLE community_members(
memberid SERIAL PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
"""
_tablename_ = 'community_members'
_cache_ = {}
memberid = Integer(primary=True)
communityid = Integer()
profileid = Integer()
created_at = Timestamp()

View File

@@ -0,0 +1,138 @@
from typing import Optional, Self
import discord
from meta import LionBot
from utils.lib import utc_now
from . import logger
from .data import ProfileData
class UserProfile:
def __init__(self, bot: LionBot, profile_row):
self.bot = bot
self.profile_row: ProfileData.UserProfileRow = profile_row
@property
def cog(self):
return self.bot.get_cog('ProfileCog')
@property
def data(self) -> ProfileData:
return self.cog.data
@property
def profileid(self):
return self.profile_row.profileid
def __repr__(self):
return f"<UserProfile profileid={self.profileid} profile={self.profile_row}>"
async def get_name(self) -> Optional[str]:
return self.profile_row.nickname
async def attach_discord(self, userid: int):
"""
Attach a new discord user to this profile.
Assumes the discord user does not itself have a profile.
"""
discord_row = await self.data.DiscordProfileRow.create(
profileid=self.profileid,
userid=userid
)
logger.info(
f"Attached discord user {userid} to profile {self!r}"
)
return discord_row
async def attach_twitch(self, userid: str):
"""
Attach a new Twitch user to this profile.
"""
twitch_row = await self.data.TwitchProfileRow.create(
profileid=self.profileid,
userid=userid
)
logger.info(
f"Attached twitch user {userid} to profile {self!r}"
)
return twitch_row
async def discord_accounts(self) -> list[ProfileData.DiscordProfileRow]:
"""
Fetch the Discord accounts associated to this profile.
"""
return await self.data.DiscordProfileRow.fetch_where(
profileid=self.profileid
).order_by(
'created_at'
)
async def twitch_accounts(self) -> list[ProfileData.TwitchProfileRow]:
"""
Fetch the Twitch accounts associated to this profile.
"""
return await self.data.TwitchProfileRow.fetch_where(
profileid=self.profileid
).order_by(
'created_at'
)
@classmethod
async def fetch(cls, bot: LionBot, profile_id: int) -> Self:
profile_row = await bot.get_cog('ProfileCog').data.UserProfileRow.fetch(profile_id)
if profile_row is None:
raise ValueError("Provided profile_id does not exist.")
return cls(bot, profile_row)
@classmethod
async def fetch_from_twitchid(cls, bot: LionBot, userid: int | str) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.TwitchProfileRow.fetch_where(userid=str(userid))
if rows:
return await cls.fetch(bot, rows[0].profileid)
@classmethod
async def fetch_from_discordid(cls, bot: LionBot, userid: int) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.DiscordProfileRow.fetch_where(userid=(userid))
if rows:
return await cls.fetch(bot, rows[0].profileid)
@classmethod
async def create(cls, bot: LionBot, **kwargs) -> Self:
"""
Create a new empty profile with the given initial arguments.
Profiles should usually be created using `create_from_discord` or `create_from_twitch`
to correctly setup initial profile preferences (e.g. name, avatar).
"""
# Create a new profile
data = bot.get_cog('ProfileCog').data
profile_row = await data.UserProfileRow.create(created_at=utc_now())
profile = await cls.fetch(bot, profile_row.profileid)
return profile
@classmethod
async def create_from_discord(cls, bot: LionBot, user: discord.Member | discord.User, **kwargs) -> Self:
"""
Create a new profile using the given Discord user as a base.
"""
kwargs.setdefault('nickname', user.name)
profile = await cls.create(bot, **kwargs)
await profile.attach_discord(user.id)
return profile
@classmethod
async def create_from_twitch(cls, bot: LionBot, user, **kwargs) -> Self:
"""
Create a new profile using the given Twitch user as a base.
Assumes the provided `user` has `id` and `name` attributes.
"""
kwargs.setdefault('nickname', user.name)
profile = await cls.create(bot, **kwargs)
await profile.attach_twitch(str(user.id))
return profile

9
src/twitch/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
import logging
logger = logging.getLogger(__name__)
from .cog import TwitchAuthCog
async def setup(bot):
await bot.add_cog(TwitchAuthCog(bot))

50
src/twitch/authclient.py Normal file
View File

@@ -0,0 +1,50 @@
"""
Testing client for the twitch AuthServer.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
import asyncio
import aiohttp
from twitchAPI.twitch import Twitch
from twitchAPI.oauth import UserAuthenticator
from twitchAPI.type import AuthScope
from meta.config import conf
URI = "http://localhost:3000/twiauth/confirm"
TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ]
async def main():
# Load in client id and secret
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI)
url = auth.return_auth_url()
# Post url to user
print(url)
# Send listen request to server
# Wait for listen request
async with aiohttp.ClientSession() as session:
async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws:
await ws.send_json({'state': auth.state})
result = await ws.receive_json()
# Hopefully get back code, print the response
print(f"Recieved: {result}")
# Authorise with code and client details
tokens = await auth.authenticate(user_token=result['code'])
if tokens:
token, refresh = tokens
await twitch.set_user_authentication(token, TARGET_SCOPE, refresh)
print(f"Authorised!")
if __name__ == '__main__':
asyncio.run(main())

86
src/twitch/authserver.py Normal file
View File

@@ -0,0 +1,86 @@
import logging
import uuid
import asyncio
from contextvars import ContextVar
import aiohttp
from aiohttp import web
logger = logging.getLogger(__name__)
reqid: ContextVar[str] = ContextVar('reqid', default='ROOT')
class AuthServer:
def __init__(self):
self.listeners = {}
async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse:
args = request.query
if 'state' not in args:
raise web.HTTPBadRequest(text="No state provided.")
if args['state'] not in self.listeners:
raise web.HTTPBadRequest(text="Invalid state.")
self.listeners[args['state']].set_result(dict(args))
return web.Response(text="Authorisation complete! You may now close this page and return to the application.")
async def handle_listen_request(self, request: web.Request) -> web.StreamResponse:
_reqid = str(uuid.uuid1())
reqid.set(_reqid)
logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}")
ws = web.WebSocketResponse()
await ws.prepare(request)
# Get the listen request data
try:
listen_req = await ws.receive_json(timeout=60)
logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}")
if 'state' not in listen_req:
logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.")
raise web.HTTPBadRequest(text="Listen request must include state string.")
elif listen_req['state'] in self.listeners:
logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.")
raise web.HTTPBadRequest(text="Invalid state string.")
except ValueError:
logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.")
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
except TypeError:
logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.")
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
except asyncio.TimeoutError:
logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.")
raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.")
except Exception:
logger.exception(f"[reqid: {_reqid}] Unknown exception.")
raise web.HTTPInternalServerError()
try:
fut = self.listeners[listen_req['state']] = asyncio.Future()
result = await asyncio.wait_for(fut, timeout=120)
except asyncio.TimeoutError:
logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.")
raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.")
finally:
self.listeners.pop(listen_req['state'], None)
logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.")
await ws.send_json(result)
await ws.close()
logger.debug(f"[reqid: {_reqid}] Request completed handling.")
return ws
def main(argv):
app = web.Application()
server = AuthServer()
app.router.add_get("/twiauth/confirm", server.handle_twitch_callback)
app.router.add_get("/twiauth/listen", server.handle_listen_request)
logger.info("App setup and configured. Starting now.")
web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080)
if __name__ == '__main__':
import sys
main(sys.argv)

114
src/twitch/cog.py Normal file
View File

@@ -0,0 +1,114 @@
import asyncio
from enum import Enum
from typing import Optional
from datetime import timedelta
import discord
from discord.ext import commands as cmds
from twitchAPI.oauth import UserAuthenticator
from twitchAPI.twitch import AuthType, Twitch
from twitchAPI.type import AuthScope
import twitchio
from twitchio.ext import commands
from data.queries import ORDER
from meta import LionCog, LionBot, CrocBot
from meta.LionContext import LionContext
from twitch.userflow import UserAuthFlow
from utils.lib import utc_now
from . import logger
from .data import TwitchAuthData
class TwitchAuthCog(LionCog):
DEFAULT_SCOPES = []
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(TwitchAuthData())
self.client_cache = {}
async def cog_load(self):
await self.data.init()
# ----- Auth API -----
async def fetch_client_for(self, userid: str):
authrow = await self.data.UserAuthRow.fetch(userid)
if authrow is None:
# TODO: Some user authentication error
self.client_cache.pop(userid, None)
raise ValueError("Requested user is not authenticated.")
if (twitch := self.client_cache.get(userid)) is None:
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
scopes = await self.data.UserAuthRow.get_scopes_for(userid)
authscopes = [AuthScope(scope) for scope in scopes]
await twitch.set_user_authentication(authrow.access_token, authscopes, authrow.refresh_token)
self.client_cache[userid] = twitch
return twitch
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
"""
Checks whether the given userid is authorised.
If 'scopes' is given, will also check the user has all of the given scopes.
"""
authrow = await self.data.UserAuthRow.fetch(userid)
if authrow:
if scopes:
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
desired = {scope.value for scope in scopes}
has_auth = desired.issubset(has_scopes)
logger.info(f"Auth check for `{userid}`: Requested scopes {desired}, has scopes {has_scopes}. Passed: {has_auth}")
else:
has_auth = True
else:
has_auth = False
return has_auth
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
"""
Start the user authentication flow for the given userid.
Will request the given scopes along with the default ones and any existing scopes.
"""
self.client_cache.pop(userid, None)
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
existing = map(AuthScope, existing_strs)
to_request = set(existing).union(scopes)
return await self.start_auth(to_request)
async def start_auth(self, scopes = []):
# TODO: Work out a way to just clone the current twitch object
# Or can we otherwise build UserAuthenticator without app auth?
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
await flow.setup()
return flow
# ----- Commands -----
@cmds.hybrid_command(name='auth')
async def cmd_auth(self, ctx: LionContext):
if ctx.interaction:
await ctx.interaction.response.defer(ephemeral=True)
flow = await self.start_auth()
await ctx.reply(flow.auth.return_auth_url())
await flow.run()
await ctx.reply("Authentication Complete!")
@cmds.hybrid_command(name='modauth')
async def cmd_modauth(self, ctx: LionContext):
if ctx.interaction:
await ctx.interaction.response.defer(ephemeral=True)
scopes = [
AuthScope.MODERATOR_READ_FOLLOWERS,
AuthScope.CHANNEL_READ_REDEMPTIONS,
AuthScope.MODERATOR_MANAGE_CHAT_MESSAGES,
]
flow = await self.start_auth(scopes=scopes)
await ctx.reply(flow.auth.return_auth_url())
await flow.run()
await ctx.reply("Authentication Complete!")

79
src/twitch/data.py Normal file
View File

@@ -0,0 +1,79 @@
import datetime as dt
from data import Registry, RowModel, Table
from data.columns import Integer, String, Timestamp
class TwitchAuthData(Registry):
class UserAuthRow(RowModel):
"""
Schema
------
CREATE TABLE twitch_user_auth(
userid TEXT PRIMARY KEY,
access_token TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
refresh_token TEXT NOT NULL,
obtained_at TIMESTAMPTZ
);
"""
_tablename_ = 'twitch_user_auth'
_cache_ = {}
userid = Integer(primary=True)
access_token = String()
refresh_token = String()
expires_at = Timestamp()
obtained_at = Timestamp()
@classmethod
async def update_user_auth(
cls, userid: str, token: str, refresh: str,
expires_at: dt.datetime, obtained_at: dt.datetime,
scopes: list[str]
):
if cls._connector is None:
raise ValueError("Attempting to use uninitialised Registry.")
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
# Clear row for this userid
await cls.table.delete_where(userid=userid)
# Insert new user row
row = await cls.create(
userid=userid,
access_token=token,
refresh_token=refresh,
expires_at=expires_at,
obtained_at=obtained_at
)
# Insert new scope rows
if scopes:
await TwitchAuthData.user_scopes.insert_many(
('userid', 'scope'),
*((userid, scope) for scope in scopes)
)
return row
@classmethod
async def get_scopes_for(cls, userid: str) -> list[str]:
"""
Get a list of scopes stored for the given user.
Will return an empty list if the user is not authenticated.
"""
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
return [row['scope'] for row in rows] if rows else []
"""
Schema
------
CREATE TABLE twitch_user_scopes(
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
scope TEXT
);
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
"""
user_scopes = Table('twitch_user_scopes')

0
src/twitch/lib.py Normal file
View File

88
src/twitch/userflow.py Normal file
View File

@@ -0,0 +1,88 @@
from typing import Optional
import datetime as dt
from aiohttp import web
import aiohttp
from twitchAPI.twitch import Twitch
from twitchAPI.oauth import UserAuthenticator, validate_token
from twitchAPI.type import AuthType
from twitchio.client import asyncio
from meta.errors import SafeCancellation
from utils.lib import utc_now
from .data import TwitchAuthData
from . import logger
class UserAuthFlow:
auth: UserAuthenticator
data: TwitchAuthData
auth_ws: str
def __init__(self, data, auth, auth_ws):
self.auth = auth
self.data = data
self.auth_ws = auth_ws
self._setup_done = asyncio.Event()
self._comm_task: Optional[asyncio.Task] = None
async def setup(self):
"""
Establishes websocket connection to the AuthServer,
and requests listening for the given state.
Propagates any exceptions that occur during connection setup.
"""
if self._setup_done.is_set():
raise ValueError("UserAuthFlow is already set up.")
self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate')
await self._setup_done.wait()
if self._comm_task.done() and (exc := self._comm_task.exception()):
raise exc
async def _communicate(self):
async with aiohttp.ClientSession() as session:
async with session.ws_connect(self.auth_ws) as ws:
await ws.send_json({'state': self.auth.state})
self._setup_done.set()
return await ws.receive_json()
async def run(self) -> TwitchAuthData.UserAuthRow:
if not self._setup_done.is_set():
raise ValueError("Cannot run UserAuthFlow before setup.")
if self._comm_task is None:
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
result = await self._comm_task
if result.get('error', None):
# TODO Custom auth errors
# This is only documented to occur when the user denies the auth
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
if result.get('state', None) != self.auth.state:
# This should never happen unless the authserver has its wires crossed somehow,
# or the connection has been tampered with.
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
logger.critical(
f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG."
)
raise SafeCancellation(
"Could not complete authentication! Invalid server response."
)
# Now assume result has a valid code
# Exchange code for an auth token and a refresh token
# Ignore type here, authenticate returns None if a callback function has been given.
token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore
# Fetch the associated userid and basic info
v_result = await validate_token(token)
userid = v_result['user_id']
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
# Save auth data
return await self.data.UserAuthRow.update_user_auth(
userid=userid, token=token, refresh=refresh,
expires_at=expiry, obtained_at=utc_now(),
scopes=[scope.value for scope in self.auth.scopes]
)