Compare commits
2 Commits
9625dec1e4
...
8421c5359d
| Author | SHA1 | Date | |
|---|---|---|---|
| 8421c5359d | |||
| 2cf81c38e8 |
@@ -4,3 +4,5 @@ discord.py [voice]
|
|||||||
iso8601
|
iso8601
|
||||||
psycopg[pool]
|
psycopg[pool]
|
||||||
pytz
|
pytz
|
||||||
|
twitchio
|
||||||
|
twitchAPI
|
||||||
|
|||||||
10
src/bot.py
10
src/bot.py
@@ -4,6 +4,7 @@ import logging
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
from twitchAPI.twitch import Twitch
|
||||||
|
|
||||||
from meta import LionBot, conf, sharding, appname
|
from meta import LionBot, conf, sharding, appname
|
||||||
from meta.app import shardname
|
from meta.app import shardname
|
||||||
@@ -49,13 +50,15 @@ async def _data_monitor() -> ComponentStatus:
|
|||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
log_action_stack.set(("Initialising",))
|
log_action_stack.set(("Initialising",))
|
||||||
logger.info("Initialising StudyLion")
|
logger.info("Initialising LionBot")
|
||||||
|
|
||||||
intents = discord.Intents.all()
|
intents = discord.Intents.all()
|
||||||
intents.members = True
|
intents.members = True
|
||||||
intents.message_content = True
|
intents.message_content = True
|
||||||
intents.presences = False
|
intents.presences = False
|
||||||
|
|
||||||
|
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
|
||||||
|
|
||||||
async with db.open():
|
async with db.open():
|
||||||
version = await db.version()
|
version = await db.version()
|
||||||
if version.version != DATA_VERSION:
|
if version.version != DATA_VERSION:
|
||||||
@@ -82,6 +85,7 @@ async def main():
|
|||||||
help_command=None,
|
help_command=None,
|
||||||
proxy=conf.bot.get('proxy', None),
|
proxy=conf.bot.get('proxy', None),
|
||||||
chunk_guilds_at_startup=False,
|
chunk_guilds_at_startup=False,
|
||||||
|
twitch=twitch
|
||||||
) as lionbot:
|
) as lionbot:
|
||||||
ctx_bot.set(lionbot)
|
ctx_bot.set(lionbot)
|
||||||
lionbot.system_monitor.add_component(
|
lionbot.system_monitor.add_component(
|
||||||
@@ -89,11 +93,11 @@ async def main():
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
log_context.set(f"APP: {appname}")
|
log_context.set(f"APP: {appname}")
|
||||||
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
|
logger.info("LionBot initialised, starting!", extra={'action': 'Starting'})
|
||||||
await lionbot.start(conf.bot['TOKEN'])
|
await lionbot.start(conf.bot['TOKEN'])
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
log_context.set(f"APP: {appname}")
|
log_context.set(f"APP: {appname}")
|
||||||
logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
|
logger.info("LionBot closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
def _main():
|
def _main():
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
|
|||||||
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
|
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
|
||||||
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
|
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
from twitchAPI.twitch import Twitch
|
||||||
|
|
||||||
from data import Database
|
from data import Database
|
||||||
from utils.lib import tabulate
|
from utils.lib import tabulate
|
||||||
@@ -23,6 +24,8 @@ from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStat
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.cog import CoreCog
|
from core.cog import CoreCog
|
||||||
|
from twitch.cog import TwitchAuthCog
|
||||||
|
from modules.profiles.cog import ProfileCog
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -31,7 +34,9 @@ class LionBot(Bot):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, *args, appname: str, shardname: str, db: Database, config: Conf,
|
self, *args, appname: str, shardname: str, db: Database, config: Conf,
|
||||||
initial_extensions: List[str], web_client: ClientSession,
|
initial_extensions: List[str], web_client: ClientSession,
|
||||||
testing_guilds: List[int] = [], **kwargs
|
twitch: Twitch,
|
||||||
|
testing_guilds: List[int] = [],
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
kwargs.setdefault('tree_cls', LionTree)
|
kwargs.setdefault('tree_cls', LionTree)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -43,6 +48,7 @@ class LionBot(Bot):
|
|||||||
self.shardname = shardname
|
self.shardname = shardname
|
||||||
# self.appdata = appdata
|
# self.appdata = appdata
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.twitch = twitch
|
||||||
|
|
||||||
self.system_monitor = SystemMonitor()
|
self.system_monitor = SystemMonitor()
|
||||||
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
|
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
|
||||||
@@ -101,6 +107,14 @@ class LionBot(Bot):
|
|||||||
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
|
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_cog(self, name: Literal['ProfileCog']) -> 'ProfileCog':
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_cog(self, name: Literal['TwitchAuthCog']) -> 'TwitchAuthCog':
|
||||||
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_cog(self, name: str) -> Optional[Cog]:
|
def get_cog(self, name: str) -> Optional[Cog]:
|
||||||
...
|
...
|
||||||
|
|||||||
8
src/modules/profiles/__init__.py
Normal file
8
src/modules/profiles/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from .cog import ProfileCog
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
await bot.add_cog(ProfileCog(bot))
|
||||||
436
src/modules/profiles/cog.py
Normal file
436
src/modules/profiles/cog.py
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
import asyncio
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, overload
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord import app_commands as appcmds
|
||||||
|
from discord.ext import commands as cmds
|
||||||
|
from twitchAPI.helper import first
|
||||||
|
from twitchAPI.type import AuthScope
|
||||||
|
import twitchio
|
||||||
|
from twitchio.ext import commands
|
||||||
|
from twitchio import User
|
||||||
|
from twitchAPI.object.api import TwitchUser
|
||||||
|
|
||||||
|
|
||||||
|
from data.queries import ORDER
|
||||||
|
from meta import LionCog, LionBot, LionContext
|
||||||
|
from meta.logger import log_wrap
|
||||||
|
from utils.lib import utc_now
|
||||||
|
from . import logger
|
||||||
|
from .data import ProfileData
|
||||||
|
from .profile import UserProfile
|
||||||
|
from .community import Community
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileCog(LionCog):
|
||||||
|
def __init__(self, bot: LionBot):
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
self.data = bot.db.load_registry(ProfileData())
|
||||||
|
|
||||||
|
self._profile_migrators = {}
|
||||||
|
self._comm_migrators = {}
|
||||||
|
|
||||||
|
async def cog_load(self):
|
||||||
|
await self.data.init()
|
||||||
|
|
||||||
|
async def cog_check(self, ctx):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Profile API
|
||||||
|
def add_profile_migrator(self, migrator, name=None):
|
||||||
|
name = name or migrator.__name__
|
||||||
|
self._profile_migrators[name or migrator.__name__] = migrator
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Added user profile migrator {name}: {migrator}"
|
||||||
|
)
|
||||||
|
return migrator
|
||||||
|
|
||||||
|
def del_profile_migrator(self, name: str):
|
||||||
|
migrator = self._profile_migrators.pop(name, None)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Removed user profile migrator {name}: {migrator}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_wrap(action="profile migration")
|
||||||
|
async def migrate_profile(self, source_profile, target_profile) -> list[str]:
|
||||||
|
logger.info(
|
||||||
|
f"Beginning user profile migration from {source_profile!r} to {target_profile!r}"
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
# Wrap this in a transaction so if something goes wrong with migration,
|
||||||
|
# we roll back safely (although this may mess up caches)
|
||||||
|
async with self.bot.db.connection() as conn:
|
||||||
|
self.bot.db.conn = conn
|
||||||
|
async with conn.transaction():
|
||||||
|
for name, migrator in self._profile_migrators.items():
|
||||||
|
try:
|
||||||
|
result = await migrator(source_profile, target_profile)
|
||||||
|
if result:
|
||||||
|
results.append(result)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"Unexpected exception running user profile migrator {name} "
|
||||||
|
f"migrating {source_profile!r} to {target_profile!r}."
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Move all Discord and Twitch profile references over to the new profile
|
||||||
|
discord_rows = await self.data.DiscordProfileRow.table.update_where(
|
||||||
|
profileid=source_profile.profileid
|
||||||
|
).set(profileid=target_profile.profileid)
|
||||||
|
results.append(f"Migrated {len(discord_rows)} attached discord account(s).")
|
||||||
|
|
||||||
|
twitch_rows = await self.data.TwitchProfileRow.table.update_where(
|
||||||
|
profileid=source_profile.profileid
|
||||||
|
).set(profileid=target_profile.profileid)
|
||||||
|
results.append(f"Migrated {len(twitch_rows)} attached twitch account(s).")
|
||||||
|
|
||||||
|
# And then mark the old profile as migrated
|
||||||
|
await source_profile.profile_row.update(migrated=target_profile.profileid)
|
||||||
|
results.append("Marking old profile as migrated.. finished!")
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def fetch_profile_by_id(self, profile_id: int) -> UserProfile:
|
||||||
|
"""
|
||||||
|
Fetch a UserProfile by the given id.
|
||||||
|
"""
|
||||||
|
return await UserProfile.fetch(self.bot, profile_id=profile_id)
|
||||||
|
|
||||||
|
async def fetch_profile_discord(self, user: discord.Member | discord.User) -> UserProfile:
|
||||||
|
"""
|
||||||
|
Fetch or create a UserProfile from the provided discord account.
|
||||||
|
"""
|
||||||
|
profile = await UserProfile.fetch_from_discordid(self.bot, user.id)
|
||||||
|
if profile is None:
|
||||||
|
profile = await UserProfile.create_from_discord(self.bot, user)
|
||||||
|
return profile
|
||||||
|
|
||||||
|
async def fetch_profile_twitch(self, user: twitchio.User) -> UserProfile:
|
||||||
|
"""
|
||||||
|
Fetch or create a UserProfile from the provided twitch account.
|
||||||
|
"""
|
||||||
|
profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
|
||||||
|
if profile is None:
|
||||||
|
profile = await UserProfile.create_from_twitch(self.bot, user)
|
||||||
|
return profile
|
||||||
|
|
||||||
|
# Community API
|
||||||
|
def add_community_migrator(self, migrator, name=None):
|
||||||
|
name = name or migrator.__name__
|
||||||
|
self._comm_migrators[name or migrator.__name__] = migrator
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Added community migrator {name}: {migrator}"
|
||||||
|
)
|
||||||
|
return migrator
|
||||||
|
|
||||||
|
def del_community_migrator(self, name: str):
|
||||||
|
migrator = self._comm_migrators.pop(name, None)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Removed community migrator {name}: {migrator}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_wrap(action="community migration")
|
||||||
|
async def migrate_community(self, source_comm, target_comm) -> list[str]:
|
||||||
|
logger.info(
|
||||||
|
f"Beginning community migration from {source_comm!r} to {target_comm!r}"
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
# Wrap this in a transaction so if something goes wrong with migration,
|
||||||
|
# we roll back safely (although this may mess up caches)
|
||||||
|
async with self.bot.db.connection() as conn:
|
||||||
|
self.bot.db.conn = conn
|
||||||
|
async with conn.transaction():
|
||||||
|
for name, migrator in self._comm_migrators.items():
|
||||||
|
try:
|
||||||
|
result = await migrator(source_comm, target_comm)
|
||||||
|
if result:
|
||||||
|
results.append(result)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"Unexpected exception running community migrator {name} "
|
||||||
|
f"migrating {source_comm!r} to {target_comm!r}."
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Move all Discord and Twitch community preferences over to the new profile
|
||||||
|
discord_rows = await self.data.DiscordCommunityRow.table.update_where(
|
||||||
|
profileid=source_comm.communityid
|
||||||
|
).set(communityid=target_comm.communityid)
|
||||||
|
results.append(f"Migrated {len(discord_rows)} attached discord guilds.")
|
||||||
|
|
||||||
|
twitch_rows = await self.data.TwitchCommunityRow.table.update_where(
|
||||||
|
communityid=source_comm.communityid
|
||||||
|
).set(communityid=target_comm.communityid)
|
||||||
|
results.append(f"Migrated {len(twitch_rows)} attached twitch channel(s).")
|
||||||
|
|
||||||
|
# And then mark the old community as migrated
|
||||||
|
await source_comm.update(migrated=target_comm.communityid)
|
||||||
|
results.append("Marking old community as migrated.. finished!")
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def fetch_community_by_id(self, community_id: int) -> Community:
|
||||||
|
"""
|
||||||
|
Fetch a Community by the given id.
|
||||||
|
"""
|
||||||
|
return await Community.fetch(self.bot, community_id=community_id)
|
||||||
|
|
||||||
|
async def fetch_community_discord(self, guild: discord.Guild) -> Community:
|
||||||
|
"""
|
||||||
|
Fetch or create a Community from the provided discord guild.
|
||||||
|
"""
|
||||||
|
comm = await Community.fetch_from_discordid(self.bot, guild.id)
|
||||||
|
if comm is None:
|
||||||
|
comm = await Community.create_from_discord(self.bot, guild)
|
||||||
|
return comm
|
||||||
|
|
||||||
|
async def fetch_community_twitch(self, user: twitchio.User) -> Community:
|
||||||
|
"""
|
||||||
|
Fetch or create a Community from the provided twitch account.
|
||||||
|
"""
|
||||||
|
community = await Community.fetch_from_twitchid(self.bot, user.id)
|
||||||
|
if community is None:
|
||||||
|
community = await Community.create_from_twitch(self.bot, user)
|
||||||
|
return community
|
||||||
|
|
||||||
|
# ----- Profile Commands -----
|
||||||
|
@cmds.hybrid_group(
|
||||||
|
name='profiles',
|
||||||
|
description="Base comand group for user profiles."
|
||||||
|
)
|
||||||
|
async def profiles_grp(self, ctx: LionContext):
|
||||||
|
...
|
||||||
|
|
||||||
|
@profiles_grp.group(
|
||||||
|
name='link',
|
||||||
|
description="Base command group for linking profiles"
|
||||||
|
)
|
||||||
|
async def profiles_link_grp(self, ctx: LionContext):
|
||||||
|
...
|
||||||
|
|
||||||
|
@profiles_link_grp.command(
|
||||||
|
name='twitch',
|
||||||
|
description="Link a twitch account to your current profile."
|
||||||
|
)
|
||||||
|
async def profiles_link_twitch_cmd(self, ctx: LionContext):
|
||||||
|
if not ctx.interaction:
|
||||||
|
return
|
||||||
|
|
||||||
|
await ctx.interaction.response.defer(ephemeral=True)
|
||||||
|
|
||||||
|
# Ask the user to go through auth to get their userid
|
||||||
|
auth_cog = self.bot.get_cog('TwitchAuthCog')
|
||||||
|
flow = await auth_cog.start_auth()
|
||||||
|
message = await ctx.reply(
|
||||||
|
f"Please [click here]({flow.auth.return_auth_url()}) to link your profile "
|
||||||
|
"to Twitch."
|
||||||
|
)
|
||||||
|
authrow = await flow.run()
|
||||||
|
await message.edit(
|
||||||
|
content="Authentication Complete! Beginning profile merge..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# results = await self.crocbot.fetch_users(ids=[authrow.userid])
|
||||||
|
# if not results:
|
||||||
|
# logger.error(
|
||||||
|
# f"User {authrow} obtained from Twitch authentication does not exist."
|
||||||
|
# )
|
||||||
|
# await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||||
|
# return
|
||||||
|
|
||||||
|
# user = results[0]
|
||||||
|
try:
|
||||||
|
user = await first(self.bot.twitch.get_users(user_ids=[str(authrow.userid)]))
|
||||||
|
except Exception:
|
||||||
|
logger.error(
|
||||||
|
f"Looking up user {authrow} from Twitch authentication flow raised an error.",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||||
|
return
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
logger.error(
|
||||||
|
f"User {authrow} obtained from Twitch authentication does not exist."
|
||||||
|
)
|
||||||
|
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# Retrieve author's profile if it exists
|
||||||
|
author_profile = await UserProfile.fetch_from_discordid(self.bot, ctx.author.id)
|
||||||
|
|
||||||
|
# Check if the twitch-side user has a profile
|
||||||
|
source_profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
|
||||||
|
|
||||||
|
if author_profile and source_profile is None:
|
||||||
|
# All we need to do is attach the twitch row
|
||||||
|
await author_profile.attach_twitch(user.id)
|
||||||
|
await message.edit(
|
||||||
|
content=f"Successfully added Twitch account **{user.display_name}**! There was no profile data to merge."
|
||||||
|
)
|
||||||
|
elif source_profile and author_profile is None:
|
||||||
|
# Attach the discord row to the profile
|
||||||
|
await source_profile.attach_discord(ctx.author.id)
|
||||||
|
await message.edit(
|
||||||
|
content=f"Successfully connected to Twitch profile **{user.display_name}**! There was no profile data to merge."
|
||||||
|
)
|
||||||
|
elif source_profile is None and author_profile is None:
|
||||||
|
profile = await UserProfile.create_from_discord(self.bot, ctx.author)
|
||||||
|
await profile.attach_twitch(user.id)
|
||||||
|
|
||||||
|
await message.edit(
|
||||||
|
content=f"Opened a new user profile for you and linked Twitch account **{user.display_name}**."
|
||||||
|
)
|
||||||
|
elif author_profile.profileid == source_profile.profileid:
|
||||||
|
await message.edit(
|
||||||
|
content=f"The Twitch account **{user.display_name}** is already linked to your profile!"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Migrate the existing profile data to the new profiles
|
||||||
|
try:
|
||||||
|
results = await self.migrate_profile(source_profile, author_profile)
|
||||||
|
except Exception:
|
||||||
|
await ctx.error_reply(
|
||||||
|
"An issue was encountered while merging your account profiles!\n"
|
||||||
|
"Migration rolled back, no data has been lost.\n"
|
||||||
|
"The developer has been notified. Please try again later!"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
content = '\n'.join((
|
||||||
|
"## Connecting Twitch account and merging profiles...",
|
||||||
|
*results,
|
||||||
|
"**Successfully linked account and merged profile data!**"
|
||||||
|
))
|
||||||
|
await message.edit(content=content)
|
||||||
|
|
||||||
|
# ----- Community Commands -----
|
||||||
|
@cmds.hybrid_group(
|
||||||
|
name='community',
|
||||||
|
description="Base comand group for community profiles."
|
||||||
|
)
|
||||||
|
async def community_grp(self, ctx: LionContext):
|
||||||
|
...
|
||||||
|
|
||||||
|
@community_grp.group(
|
||||||
|
name='link',
|
||||||
|
description="Base command group for linking communities"
|
||||||
|
)
|
||||||
|
async def community_link_grp(self, ctx: LionContext):
|
||||||
|
...
|
||||||
|
|
||||||
|
@community_link_grp.command(
|
||||||
|
name='twitch',
|
||||||
|
description="Link a twitch account to this community."
|
||||||
|
)
|
||||||
|
@appcmds.guild_only()
|
||||||
|
@appcmds.default_permissions(manage_guild=True)
|
||||||
|
async def comm_link_twitch_cmd(self, ctx: LionContext):
|
||||||
|
if not ctx.interaction:
|
||||||
|
return
|
||||||
|
assert ctx.guild is not None
|
||||||
|
|
||||||
|
await ctx.interaction.response.defer(ephemeral=True)
|
||||||
|
|
||||||
|
if not ctx.author.guild_permissions.manage_guild:
|
||||||
|
await ctx.error_reply("You need the `MANAGE_GUILD` permission to link this guild to a community.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ask the user to go through auth to get their userid
|
||||||
|
auth_cog = self.bot.get_cog('TwitchAuthCog')
|
||||||
|
flow = await auth_cog.start_auth(
|
||||||
|
scopes=[
|
||||||
|
AuthScope.CHAT_EDIT,
|
||||||
|
AuthScope.CHAT_READ,
|
||||||
|
AuthScope.MODERATION_READ,
|
||||||
|
AuthScope.CHANNEL_BOT,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
message = await ctx.reply(
|
||||||
|
f"Please [click here]({flow.auth.return_auth_url()}) to link your Twitch channel to this server."
|
||||||
|
)
|
||||||
|
authrow = await flow.run()
|
||||||
|
await message.edit(
|
||||||
|
content="Authentication Complete! Beginning community profile merge..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# results = await self.crocbot.fetch_users(ids=[authrow.userid])
|
||||||
|
# if not results:
|
||||||
|
# logger.error(
|
||||||
|
# f"User {authrow} obtained from Twitch authentication does not exist."
|
||||||
|
# )
|
||||||
|
# await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||||
|
# return
|
||||||
|
|
||||||
|
# user = results[0]
|
||||||
|
try:
|
||||||
|
user = await first(self.bot.twitch.get_users(user_ids=[str(authrow.userid)]))
|
||||||
|
except Exception:
|
||||||
|
logger.error(
|
||||||
|
f"Looking up user {authrow} from Twitch authentication flow raised an error.",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||||
|
return
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
logger.error(
|
||||||
|
f"User {authrow} obtained from Twitch authentication does not exist."
|
||||||
|
)
|
||||||
|
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Retrieve author's profile if it exists
|
||||||
|
guild_comm = await Community.fetch_from_discordid(self.bot, ctx.guild.id)
|
||||||
|
|
||||||
|
# Check if the twitch-side user has a profile
|
||||||
|
twitch_comm = await Community.fetch_from_twitchid(self.bot, user.id)
|
||||||
|
|
||||||
|
if guild_comm and twitch_comm is None:
|
||||||
|
# All we need to do is attach the twitch row
|
||||||
|
await guild_comm.attach_twitch(user.id)
|
||||||
|
await message.edit(
|
||||||
|
content=f"Successfully linked Twitch channel **{user.display_name}**! There was no community data to merge."
|
||||||
|
)
|
||||||
|
elif twitch_comm and guild_comm is None:
|
||||||
|
# Attach the discord row to the profile
|
||||||
|
await twitch_comm.attach_discord(ctx.guild.id)
|
||||||
|
await message.edit(
|
||||||
|
content=f"Successfully connected to Twitch channel **{user.display_name}**!"
|
||||||
|
)
|
||||||
|
elif twitch_comm is None and guild_comm is None:
|
||||||
|
profile = await Community.create_from_discord(self.bot, ctx.guild)
|
||||||
|
await profile.attach_twitch(user.id)
|
||||||
|
|
||||||
|
await message.edit(
|
||||||
|
content=f"Created a new community for this server and linked Twitch account **{user.display_name}**."
|
||||||
|
)
|
||||||
|
elif guild_comm.communityid == twitch_comm.communityid:
|
||||||
|
await message.edit(
|
||||||
|
content=f"This server is already linked to the Twitch channel **{user.display_name}**!"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Migrate the existing profile data to the new profiles
|
||||||
|
try:
|
||||||
|
results = await self.migrate_community(twitch_comm, guild_comm)
|
||||||
|
except Exception:
|
||||||
|
await ctx.error_reply(
|
||||||
|
"An issue was encountered while merging your community profiles!\n"
|
||||||
|
"Migration rolled back, no data has been lost.\n"
|
||||||
|
"The developer has been notified. Please try again later!"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
content = '\n'.join((
|
||||||
|
"## Connecting Twitch account and merging community profiles...",
|
||||||
|
*results,
|
||||||
|
"**Successfully linked account and merged community data!**"
|
||||||
|
))
|
||||||
|
await message.edit(content=content)
|
||||||
123
src/modules/profiles/community.py
Normal file
123
src/modules/profiles/community.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from meta import LionBot
|
||||||
|
from utils.lib import utc_now
|
||||||
|
|
||||||
|
from . import logger
|
||||||
|
from .data import ProfileData
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Community:
|
||||||
|
def __init__(self, bot: LionBot, community_row):
|
||||||
|
self.bot = bot
|
||||||
|
self.row: ProfileData.CommunityRow = community_row
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cog(self):
|
||||||
|
return self.bot.get_cog('ProfileCog')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> ProfileData:
|
||||||
|
return self.cog.data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def communityid(self):
|
||||||
|
return self.row.communityid
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<Community communityid={self.communityid} row={self.row}>"
|
||||||
|
|
||||||
|
async def attach_discord(self, guildid: int):
|
||||||
|
"""
|
||||||
|
Attach a new discord guild to this community.
|
||||||
|
Assumes the discord guild is not already associated to a community.
|
||||||
|
"""
|
||||||
|
discord_row = await self.data.DiscordCommunityRow.create(
|
||||||
|
communityid=self.communityid,
|
||||||
|
guildid=guildid
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Attached discord guild {guildid} to community {self!r}"
|
||||||
|
)
|
||||||
|
return discord_row
|
||||||
|
|
||||||
|
async def attach_twitch(self, channelid: str):
|
||||||
|
"""
|
||||||
|
Attach a new Twitch user channel to this community.
|
||||||
|
"""
|
||||||
|
twitch_row = await self.data.TwitchCommunityRow.create(
|
||||||
|
communityid=self.communityid,
|
||||||
|
channelid=str(channelid)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Attached twitch channel {channelid} to community {self!r}"
|
||||||
|
)
|
||||||
|
return twitch_row
|
||||||
|
|
||||||
|
async def discord_guilds(self) -> list[ProfileData.DiscordCommunityRow]:
|
||||||
|
"""
|
||||||
|
Fetch the Discord guild rows associated to this community.
|
||||||
|
"""
|
||||||
|
return await self.data.DiscordCommunityRow.fetch_where(communityid=self.communityid)
|
||||||
|
|
||||||
|
async def twitch_channels(self) -> list[ProfileData.TwitchCommunityRow]:
|
||||||
|
"""
|
||||||
|
Fetch the Twitch user rows associated to this profile.
|
||||||
|
"""
|
||||||
|
return await self.data.TwitchCommunityRow.fetch_where(communityid=self.communityid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch(cls, bot: LionBot, community_id: int) -> Self:
|
||||||
|
community_row = await bot.get_cog('ProfileCog').data.CommunityRow.fetch(community_id)
|
||||||
|
if community_row is None:
|
||||||
|
raise ValueError("Provided community_id does not exist.")
|
||||||
|
return cls(bot, community_row)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_twitchid(cls, bot: LionBot, channelid: int | str) -> Optional[Self]:
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
rows = await data.TwitchCommunityRow.fetch_where(channelid=str(channelid))
|
||||||
|
if rows:
|
||||||
|
return await cls.fetch(bot, rows[0].communityid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_discordid(cls, bot: LionBot, guildid: int) -> Optional[Self]:
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
rows = await data.DiscordCommunityRow.fetch_where(guildid=guildid)
|
||||||
|
if rows:
|
||||||
|
return await cls.fetch(bot, rows[0].communityid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, bot: LionBot, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new empty community with the given initial arguments.
|
||||||
|
|
||||||
|
Communities should usually be created using `create_from_discord` or `create_from_twitch`
|
||||||
|
to correctly setup initial preferences (e.g. name, avatar).
|
||||||
|
"""
|
||||||
|
# Create a new community
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
row = await data.CommunityRow.create(created_at=utc_now(), **kwargs)
|
||||||
|
return await cls.fetch(bot, row.communityid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_from_discord(cls, bot: LionBot, guild: discord.Guild, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new community using the given Discord guild as a base.
|
||||||
|
"""
|
||||||
|
self = await cls.create(bot, **kwargs)
|
||||||
|
await self.attach_discord(guild.id)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_from_twitch(cls, bot: LionBot, user, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new profile using the given Twitch channel user as a base.
|
||||||
|
The provided `user` must have an `id` attribute.
|
||||||
|
"""
|
||||||
|
self = await cls.create(bot, **kwargs)
|
||||||
|
await self.attach_twitch(str(user.id))
|
||||||
|
return self
|
||||||
158
src/modules/profiles/data.py
Normal file
158
src/modules/profiles/data.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
from data import Registry, RowModel
|
||||||
|
from data.columns import Integer, String, Timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileData(Registry):
|
||||||
|
class UserProfileRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE user_profiles(
|
||||||
|
profileid SERIAL PRIMARY KEY,
|
||||||
|
nickname TEXT,
|
||||||
|
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'user_profiles'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
profileid = Integer(primary=True)
|
||||||
|
nickname = String()
|
||||||
|
migrated = Integer()
|
||||||
|
created_at = Timestamp()
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordProfileRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE profiles_discord(
|
||||||
|
linkid SERIAL PRIMARY KEY,
|
||||||
|
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
userid BIGINT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid);
|
||||||
|
CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'profiles_discord'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
linkid = Integer(primary=True)
|
||||||
|
profileid = Integer()
|
||||||
|
userid = Integer()
|
||||||
|
created_at = Integer()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_profile(cls, profileid: int):
|
||||||
|
rows = await cls.fetch_where(profiled=profileid)
|
||||||
|
return next(rows, None)
|
||||||
|
|
||||||
|
|
||||||
|
class TwitchProfileRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE profiles_twitch(
|
||||||
|
linkid SERIAL PRIMARY KEY,
|
||||||
|
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
userid TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid);
|
||||||
|
CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'profiles_twitch'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
linkid = Integer(primary=True)
|
||||||
|
profileid = Integer()
|
||||||
|
userid = String()
|
||||||
|
created_at = Timestamp()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_profile(cls, profileid: int):
|
||||||
|
rows = await cls.fetch_where(profiled=profileid)
|
||||||
|
return next(rows, None)
|
||||||
|
|
||||||
|
class CommunityRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE communities(
|
||||||
|
communityid SERIAL PRIMARY KEY,
|
||||||
|
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'communities'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
communityid = Integer(primary=True)
|
||||||
|
migrated = Integer()
|
||||||
|
created_at = Timestamp()
|
||||||
|
|
||||||
|
class DiscordCommunityRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE communities_discord(
|
||||||
|
guildid BIGINT PRIMARY KEY,
|
||||||
|
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'communities_discord'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
guildid = Integer(primary=True)
|
||||||
|
communityid = Integer()
|
||||||
|
linked_at = Timestamp()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_community(cls, communityid: int):
|
||||||
|
rows = await cls.fetch_where(communityd=communityid)
|
||||||
|
return next(rows, None)
|
||||||
|
|
||||||
|
class TwitchCommunityRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE communities_twitch(
|
||||||
|
channelid TEXT PRIMARY KEY,
|
||||||
|
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'communities_twitch'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
channelid = String(primary=True)
|
||||||
|
communityid = Integer()
|
||||||
|
linked_at = Timestamp()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_community(cls, communityid: int):
|
||||||
|
rows = await cls.fetch_where(communityd=communityid)
|
||||||
|
return next(rows, None)
|
||||||
|
|
||||||
|
class CommunityMemberRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE community_members(
|
||||||
|
memberid SERIAL PRIMARY KEY,
|
||||||
|
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'community_members'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
memberid = Integer(primary=True)
|
||||||
|
communityid = Integer()
|
||||||
|
profileid = Integer()
|
||||||
|
created_at = Timestamp()
|
||||||
138
src/modules/profiles/profile.py
Normal file
138
src/modules/profiles/profile.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from meta import LionBot
|
||||||
|
from utils.lib import utc_now
|
||||||
|
|
||||||
|
from . import logger
|
||||||
|
from .data import ProfileData
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class UserProfile:
|
||||||
|
def __init__(self, bot: LionBot, profile_row):
|
||||||
|
self.bot = bot
|
||||||
|
self.profile_row: ProfileData.UserProfileRow = profile_row
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cog(self):
|
||||||
|
return self.bot.get_cog('ProfileCog')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> ProfileData:
|
||||||
|
return self.cog.data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def profileid(self):
|
||||||
|
return self.profile_row.profileid
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<UserProfile profileid={self.profileid} profile={self.profile_row}>"
|
||||||
|
|
||||||
|
async def get_name(self) -> Optional[str]:
|
||||||
|
return self.profile_row.nickname
|
||||||
|
|
||||||
|
async def attach_discord(self, userid: int):
|
||||||
|
"""
|
||||||
|
Attach a new discord user to this profile.
|
||||||
|
Assumes the discord user does not itself have a profile.
|
||||||
|
"""
|
||||||
|
discord_row = await self.data.DiscordProfileRow.create(
|
||||||
|
profileid=self.profileid,
|
||||||
|
userid=userid
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Attached discord user {userid} to profile {self!r}"
|
||||||
|
)
|
||||||
|
return discord_row
|
||||||
|
|
||||||
|
async def attach_twitch(self, userid: str):
|
||||||
|
"""
|
||||||
|
Attach a new Twitch user to this profile.
|
||||||
|
"""
|
||||||
|
twitch_row = await self.data.TwitchProfileRow.create(
|
||||||
|
profileid=self.profileid,
|
||||||
|
userid=userid
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Attached twitch user {userid} to profile {self!r}"
|
||||||
|
)
|
||||||
|
return twitch_row
|
||||||
|
|
||||||
|
async def discord_accounts(self) -> list[ProfileData.DiscordProfileRow]:
|
||||||
|
"""
|
||||||
|
Fetch the Discord accounts associated to this profile.
|
||||||
|
"""
|
||||||
|
return await self.data.DiscordProfileRow.fetch_where(
|
||||||
|
profileid=self.profileid
|
||||||
|
).order_by(
|
||||||
|
'created_at'
|
||||||
|
)
|
||||||
|
|
||||||
|
async def twitch_accounts(self) -> list[ProfileData.TwitchProfileRow]:
|
||||||
|
"""
|
||||||
|
Fetch the Twitch accounts associated to this profile.
|
||||||
|
"""
|
||||||
|
return await self.data.TwitchProfileRow.fetch_where(
|
||||||
|
profileid=self.profileid
|
||||||
|
).order_by(
|
||||||
|
'created_at'
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch(cls, bot: LionBot, profile_id: int) -> Self:
|
||||||
|
profile_row = await bot.get_cog('ProfileCog').data.UserProfileRow.fetch(profile_id)
|
||||||
|
if profile_row is None:
|
||||||
|
raise ValueError("Provided profile_id does not exist.")
|
||||||
|
return cls(bot, profile_row)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_twitchid(cls, bot: LionBot, userid: int | str) -> Optional[Self]:
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
rows = await data.TwitchProfileRow.fetch_where(userid=str(userid))
|
||||||
|
if rows:
|
||||||
|
return await cls.fetch(bot, rows[0].profileid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_discordid(cls, bot: LionBot, userid: int) -> Optional[Self]:
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
rows = await data.DiscordProfileRow.fetch_where(userid=(userid))
|
||||||
|
if rows:
|
||||||
|
return await cls.fetch(bot, rows[0].profileid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, bot: LionBot, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new empty profile with the given initial arguments.
|
||||||
|
|
||||||
|
Profiles should usually be created using `create_from_discord` or `create_from_twitch`
|
||||||
|
to correctly setup initial profile preferences (e.g. name, avatar).
|
||||||
|
"""
|
||||||
|
# Create a new profile
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
profile_row = await data.UserProfileRow.create(created_at=utc_now())
|
||||||
|
profile = await cls.fetch(bot, profile_row.profileid)
|
||||||
|
return profile
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_from_discord(cls, bot: LionBot, user: discord.Member | discord.User, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new profile using the given Discord user as a base.
|
||||||
|
"""
|
||||||
|
kwargs.setdefault('nickname', user.name)
|
||||||
|
profile = await cls.create(bot, **kwargs)
|
||||||
|
await profile.attach_discord(user.id)
|
||||||
|
return profile
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_from_twitch(cls, bot: LionBot, user, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new profile using the given Twitch user as a base.
|
||||||
|
|
||||||
|
Assumes the provided `user` has `id` and `name` attributes.
|
||||||
|
"""
|
||||||
|
kwargs.setdefault('nickname', user.name)
|
||||||
|
profile = await cls.create(bot, **kwargs)
|
||||||
|
await profile.attach_twitch(str(user.id))
|
||||||
|
return profile
|
||||||
9
src/twitch/__init__.py
Normal file
9
src/twitch/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from .cog import TwitchAuthCog
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
await bot.add_cog(TwitchAuthCog(bot))
|
||||||
|
|
||||||
50
src/twitch/authclient.py
Normal file
50
src/twitch/authclient.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""
|
||||||
|
Testing client for the twitch AuthServer.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from twitchAPI.twitch import Twitch
|
||||||
|
from twitchAPI.oauth import UserAuthenticator
|
||||||
|
from twitchAPI.type import AuthScope
|
||||||
|
|
||||||
|
from meta.config import conf
|
||||||
|
|
||||||
|
|
||||||
|
URI = "http://localhost:3000/twiauth/confirm"
|
||||||
|
TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ]
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Load in client id and secret
|
||||||
|
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
|
||||||
|
auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI)
|
||||||
|
url = auth.return_auth_url()
|
||||||
|
|
||||||
|
# Post url to user
|
||||||
|
print(url)
|
||||||
|
|
||||||
|
# Send listen request to server
|
||||||
|
# Wait for listen request
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws:
|
||||||
|
await ws.send_json({'state': auth.state})
|
||||||
|
result = await ws.receive_json()
|
||||||
|
|
||||||
|
# Hopefully get back code, print the response
|
||||||
|
print(f"Recieved: {result}")
|
||||||
|
|
||||||
|
# Authorise with code and client details
|
||||||
|
tokens = await auth.authenticate(user_token=result['code'])
|
||||||
|
if tokens:
|
||||||
|
token, refresh = tokens
|
||||||
|
await twitch.set_user_authentication(token, TARGET_SCOPE, refresh)
|
||||||
|
print(f"Authorised!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(main())
|
||||||
86
src/twitch/authserver.py
Normal file
86
src/twitch/authserver.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
reqid: ContextVar[str] = ContextVar('reqid', default='ROOT')
|
||||||
|
|
||||||
|
|
||||||
|
class AuthServer:
|
||||||
|
def __init__(self):
|
||||||
|
self.listeners = {}
|
||||||
|
|
||||||
|
async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
args = request.query
|
||||||
|
if 'state' not in args:
|
||||||
|
raise web.HTTPBadRequest(text="No state provided.")
|
||||||
|
if args['state'] not in self.listeners:
|
||||||
|
raise web.HTTPBadRequest(text="Invalid state.")
|
||||||
|
self.listeners[args['state']].set_result(dict(args))
|
||||||
|
return web.Response(text="Authorisation complete! You may now close this page and return to the application.")
|
||||||
|
|
||||||
|
async def handle_listen_request(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
_reqid = str(uuid.uuid1())
|
||||||
|
reqid.set(_reqid)
|
||||||
|
|
||||||
|
logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}")
|
||||||
|
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
# Get the listen request data
|
||||||
|
try:
|
||||||
|
listen_req = await ws.receive_json(timeout=60)
|
||||||
|
logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}")
|
||||||
|
if 'state' not in listen_req:
|
||||||
|
logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.")
|
||||||
|
raise web.HTTPBadRequest(text="Listen request must include state string.")
|
||||||
|
elif listen_req['state'] in self.listeners:
|
||||||
|
logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.")
|
||||||
|
raise web.HTTPBadRequest(text="Invalid state string.")
|
||||||
|
except ValueError:
|
||||||
|
logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.")
|
||||||
|
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
|
||||||
|
except TypeError:
|
||||||
|
logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.")
|
||||||
|
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.")
|
||||||
|
raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.")
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f"[reqid: {_reqid}] Unknown exception.")
|
||||||
|
raise web.HTTPInternalServerError()
|
||||||
|
|
||||||
|
try:
|
||||||
|
fut = self.listeners[listen_req['state']] = asyncio.Future()
|
||||||
|
result = await asyncio.wait_for(fut, timeout=120)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.")
|
||||||
|
raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.")
|
||||||
|
finally:
|
||||||
|
self.listeners.pop(listen_req['state'], None)
|
||||||
|
|
||||||
|
logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.")
|
||||||
|
await ws.send_json(result)
|
||||||
|
await ws.close()
|
||||||
|
logger.debug(f"[reqid: {_reqid}] Request completed handling.")
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
app = web.Application()
|
||||||
|
server = AuthServer()
|
||||||
|
app.router.add_get("/twiauth/confirm", server.handle_twitch_callback)
|
||||||
|
app.router.add_get("/twiauth/listen", server.handle_listen_request)
|
||||||
|
|
||||||
|
logger.info("App setup and configured. Starting now.")
|
||||||
|
web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import sys
|
||||||
|
main(sys.argv)
|
||||||
114
src/twitch/cog.py
Normal file
114
src/twitch/cog.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
import asyncio
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands as cmds
|
||||||
|
|
||||||
|
from twitchAPI.oauth import UserAuthenticator
|
||||||
|
from twitchAPI.twitch import AuthType, Twitch
|
||||||
|
from twitchAPI.type import AuthScope
|
||||||
|
import twitchio
|
||||||
|
from twitchio.ext import commands
|
||||||
|
|
||||||
|
|
||||||
|
from data.queries import ORDER
|
||||||
|
from meta import LionCog, LionBot, CrocBot
|
||||||
|
from meta.LionContext import LionContext
|
||||||
|
from twitch.userflow import UserAuthFlow
|
||||||
|
from utils.lib import utc_now
|
||||||
|
from . import logger
|
||||||
|
from .data import TwitchAuthData
|
||||||
|
|
||||||
|
|
||||||
|
class TwitchAuthCog(LionCog):
|
||||||
|
DEFAULT_SCOPES = []
|
||||||
|
|
||||||
|
def __init__(self, bot: LionBot):
|
||||||
|
self.bot = bot
|
||||||
|
self.data = bot.db.load_registry(TwitchAuthData())
|
||||||
|
|
||||||
|
self.client_cache = {}
|
||||||
|
|
||||||
|
async def cog_load(self):
|
||||||
|
await self.data.init()
|
||||||
|
|
||||||
|
# ----- Auth API -----
|
||||||
|
|
||||||
|
async def fetch_client_for(self, userid: str):
|
||||||
|
authrow = await self.data.UserAuthRow.fetch(userid)
|
||||||
|
if authrow is None:
|
||||||
|
# TODO: Some user authentication error
|
||||||
|
self.client_cache.pop(userid, None)
|
||||||
|
raise ValueError("Requested user is not authenticated.")
|
||||||
|
if (twitch := self.client_cache.get(userid)) is None:
|
||||||
|
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
|
||||||
|
scopes = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||||
|
authscopes = [AuthScope(scope) for scope in scopes]
|
||||||
|
await twitch.set_user_authentication(authrow.access_token, authscopes, authrow.refresh_token)
|
||||||
|
self.client_cache[userid] = twitch
|
||||||
|
return twitch
|
||||||
|
|
||||||
|
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
|
||||||
|
"""
|
||||||
|
Checks whether the given userid is authorised.
|
||||||
|
If 'scopes' is given, will also check the user has all of the given scopes.
|
||||||
|
"""
|
||||||
|
authrow = await self.data.UserAuthRow.fetch(userid)
|
||||||
|
if authrow:
|
||||||
|
if scopes:
|
||||||
|
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||||
|
desired = {scope.value for scope in scopes}
|
||||||
|
has_auth = desired.issubset(has_scopes)
|
||||||
|
logger.info(f"Auth check for `{userid}`: Requested scopes {desired}, has scopes {has_scopes}. Passed: {has_auth}")
|
||||||
|
else:
|
||||||
|
has_auth = True
|
||||||
|
else:
|
||||||
|
has_auth = False
|
||||||
|
return has_auth
|
||||||
|
|
||||||
|
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
|
||||||
|
"""
|
||||||
|
Start the user authentication flow for the given userid.
|
||||||
|
Will request the given scopes along with the default ones and any existing scopes.
|
||||||
|
"""
|
||||||
|
self.client_cache.pop(userid, None)
|
||||||
|
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||||
|
existing = map(AuthScope, existing_strs)
|
||||||
|
to_request = set(existing).union(scopes)
|
||||||
|
return await self.start_auth(to_request)
|
||||||
|
|
||||||
|
async def start_auth(self, scopes = []):
|
||||||
|
# TODO: Work out a way to just clone the current twitch object
|
||||||
|
# Or can we otherwise build UserAuthenticator without app auth?
|
||||||
|
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
|
||||||
|
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
|
||||||
|
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
|
||||||
|
await flow.setup()
|
||||||
|
|
||||||
|
return flow
|
||||||
|
|
||||||
|
# ----- Commands -----
|
||||||
|
@cmds.hybrid_command(name='auth')
|
||||||
|
async def cmd_auth(self, ctx: LionContext):
|
||||||
|
if ctx.interaction:
|
||||||
|
await ctx.interaction.response.defer(ephemeral=True)
|
||||||
|
flow = await self.start_auth()
|
||||||
|
await ctx.reply(flow.auth.return_auth_url())
|
||||||
|
await flow.run()
|
||||||
|
await ctx.reply("Authentication Complete!")
|
||||||
|
|
||||||
|
@cmds.hybrid_command(name='modauth')
|
||||||
|
async def cmd_modauth(self, ctx: LionContext):
|
||||||
|
if ctx.interaction:
|
||||||
|
await ctx.interaction.response.defer(ephemeral=True)
|
||||||
|
scopes = [
|
||||||
|
AuthScope.MODERATOR_READ_FOLLOWERS,
|
||||||
|
AuthScope.CHANNEL_READ_REDEMPTIONS,
|
||||||
|
AuthScope.MODERATOR_MANAGE_CHAT_MESSAGES,
|
||||||
|
]
|
||||||
|
flow = await self.start_auth(scopes=scopes)
|
||||||
|
await ctx.reply(flow.auth.return_auth_url())
|
||||||
|
await flow.run()
|
||||||
|
await ctx.reply("Authentication Complete!")
|
||||||
79
src/twitch/data.py
Normal file
79
src/twitch/data.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import datetime as dt
|
||||||
|
|
||||||
|
from data import Registry, RowModel, Table
|
||||||
|
from data.columns import Integer, String, Timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class TwitchAuthData(Registry):
|
||||||
|
class UserAuthRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE twitch_user_auth(
|
||||||
|
userid TEXT PRIMARY KEY,
|
||||||
|
access_token TEXT NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
refresh_token TEXT NOT NULL,
|
||||||
|
obtained_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'twitch_user_auth'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
userid = Integer(primary=True)
|
||||||
|
access_token = String()
|
||||||
|
refresh_token = String()
|
||||||
|
expires_at = Timestamp()
|
||||||
|
obtained_at = Timestamp()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def update_user_auth(
|
||||||
|
cls, userid: str, token: str, refresh: str,
|
||||||
|
expires_at: dt.datetime, obtained_at: dt.datetime,
|
||||||
|
scopes: list[str]
|
||||||
|
):
|
||||||
|
if cls._connector is None:
|
||||||
|
raise ValueError("Attempting to use uninitialised Registry.")
|
||||||
|
async with cls._connector.connection() as conn:
|
||||||
|
cls._connector.conn = conn
|
||||||
|
async with conn.transaction():
|
||||||
|
# Clear row for this userid
|
||||||
|
await cls.table.delete_where(userid=userid)
|
||||||
|
|
||||||
|
# Insert new user row
|
||||||
|
row = await cls.create(
|
||||||
|
userid=userid,
|
||||||
|
access_token=token,
|
||||||
|
refresh_token=refresh,
|
||||||
|
expires_at=expires_at,
|
||||||
|
obtained_at=obtained_at
|
||||||
|
)
|
||||||
|
# Insert new scope rows
|
||||||
|
if scopes:
|
||||||
|
await TwitchAuthData.user_scopes.insert_many(
|
||||||
|
('userid', 'scope'),
|
||||||
|
*((userid, scope) for scope in scopes)
|
||||||
|
)
|
||||||
|
return row
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_scopes_for(cls, userid: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Get a list of scopes stored for the given user.
|
||||||
|
Will return an empty list if the user is not authenticated.
|
||||||
|
"""
|
||||||
|
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
|
||||||
|
|
||||||
|
return [row['scope'] for row in rows] if rows else []
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE twitch_user_scopes(
|
||||||
|
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
scope TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||||
|
"""
|
||||||
|
user_scopes = Table('twitch_user_scopes')
|
||||||
0
src/twitch/lib.py
Normal file
0
src/twitch/lib.py
Normal file
88
src/twitch/userflow.py
Normal file
88
src/twitch/userflow.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import datetime as dt
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from twitchAPI.twitch import Twitch
|
||||||
|
from twitchAPI.oauth import UserAuthenticator, validate_token
|
||||||
|
from twitchAPI.type import AuthType
|
||||||
|
from twitchio.client import asyncio
|
||||||
|
|
||||||
|
from meta.errors import SafeCancellation
|
||||||
|
from utils.lib import utc_now
|
||||||
|
from .data import TwitchAuthData
|
||||||
|
from . import logger
|
||||||
|
|
||||||
|
class UserAuthFlow:
|
||||||
|
auth: UserAuthenticator
|
||||||
|
data: TwitchAuthData
|
||||||
|
auth_ws: str
|
||||||
|
|
||||||
|
def __init__(self, data, auth, auth_ws):
|
||||||
|
self.auth = auth
|
||||||
|
self.data = data
|
||||||
|
self.auth_ws = auth_ws
|
||||||
|
|
||||||
|
self._setup_done = asyncio.Event()
|
||||||
|
self._comm_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
"""
|
||||||
|
Establishes websocket connection to the AuthServer,
|
||||||
|
and requests listening for the given state.
|
||||||
|
Propagates any exceptions that occur during connection setup.
|
||||||
|
"""
|
||||||
|
if self._setup_done.is_set():
|
||||||
|
raise ValueError("UserAuthFlow is already set up.")
|
||||||
|
self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate')
|
||||||
|
await self._setup_done.wait()
|
||||||
|
if self._comm_task.done() and (exc := self._comm_task.exception()):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
async def _communicate(self):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.ws_connect(self.auth_ws) as ws:
|
||||||
|
await ws.send_json({'state': self.auth.state})
|
||||||
|
self._setup_done.set()
|
||||||
|
return await ws.receive_json()
|
||||||
|
|
||||||
|
async def run(self) -> TwitchAuthData.UserAuthRow:
|
||||||
|
if not self._setup_done.is_set():
|
||||||
|
raise ValueError("Cannot run UserAuthFlow before setup.")
|
||||||
|
if self._comm_task is None:
|
||||||
|
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
|
||||||
|
|
||||||
|
result = await self._comm_task
|
||||||
|
if result.get('error', None):
|
||||||
|
# TODO Custom auth errors
|
||||||
|
# This is only documented to occur when the user denies the auth
|
||||||
|
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
|
||||||
|
|
||||||
|
if result.get('state', None) != self.auth.state:
|
||||||
|
# This should never happen unless the authserver has its wires crossed somehow,
|
||||||
|
# or the connection has been tampered with.
|
||||||
|
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
|
||||||
|
logger.critical(
|
||||||
|
f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG."
|
||||||
|
)
|
||||||
|
raise SafeCancellation(
|
||||||
|
"Could not complete authentication! Invalid server response."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now assume result has a valid code
|
||||||
|
# Exchange code for an auth token and a refresh token
|
||||||
|
# Ignore type here, authenticate returns None if a callback function has been given.
|
||||||
|
token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore
|
||||||
|
|
||||||
|
# Fetch the associated userid and basic info
|
||||||
|
v_result = await validate_token(token)
|
||||||
|
userid = v_result['user_id']
|
||||||
|
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
|
||||||
|
|
||||||
|
# Save auth data
|
||||||
|
return await self.data.UserAuthRow.update_user_auth(
|
||||||
|
userid=userid, token=token, refresh=refresh,
|
||||||
|
expires_at=expiry, obtained_at=utc_now(),
|
||||||
|
scopes=[scope.value for scope in self.auth.scopes]
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user