From bc073363b9ffd514ed3cc0ed7f47bb519375c687 Mon Sep 17 00:00:00 2001 From: Interitio Date: Fri, 6 Sep 2024 10:59:13 +1000 Subject: [PATCH 1/9] feat: Start merged profiles and communities. --- data/schema.sql | 67 ++++++++++++++++ src/modules/profiles/__init__.py | 8 ++ src/modules/profiles/cog.py | 72 +++++++++++++++++ src/modules/profiles/data.py | 134 +++++++++++++++++++++++++++++++ 4 files changed, 281 insertions(+) create mode 100644 src/modules/profiles/__init__.py create mode 100644 src/modules/profiles/cog.py create mode 100644 src/modules/profiles/data.py diff --git a/data/schema.sql b/data/schema.sql index 504ae859..24fd294b 100644 --- a/data/schema.sql +++ b/data/schema.sql @@ -1485,6 +1485,73 @@ CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name -- }}} +-- User and Community Profiles {{{ + +CREATE TABLE user_profiles( + profileid SERIAL PRIMARY KEY, + nickname TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE profiles_discord( + linkid SERIAL PRIMARY KEY, + profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, + userid BIGINT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE UNIQUE INDEX profiles_discord_profileid ON profiles_discord (profileid); +CREATE INDEX profiles_discord_userid ON profiles_discord (userid); + +CREATE TABLE profiles_twitch( + linkid SERIAL PRIMARY KEY, + profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, + userid BIGINT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE UNIQUE INDEX profiles_twitch_profileid ON profiles_twitch (profileid); +CREATE INDEX profiles_twitch_userid ON profiles_twitch (userid); + +CREATE TABLE communities( + communityid SERIAL PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE communities_discord( + guildid BIGINT PRIMARY KEY, + communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, + linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE UNIQUE INDEX communities_discord_communityid ON communities_discord (communityid); + +CREATE TABLE communities_twitch( + channelid BIGINT PRIMARY KEY, + communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, + linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE UNIQUE INDEX communities_twitch_communityid ON communities_twitch (communityid); + +CREATE TABLE community_members( + memberid SERIAL PRIMARY KEY, + communityid INTEGER NOT NULL REFERENCES communities (communityud) ON DELETE CASCADE ON UPDATE CASCADE, + profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +) +CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid); +-- }}} + +-- Twitch Auth {{ +CREATE TABLE twitch_tokens( + userid BIGINT PRIMARY KEY, + access_token TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + refresh_token TEXT NOT NULL, + obtained_at TIMESTAMPTZ +); +-- }} + + + + -- Analytics Data {{{ CREATE SCHEMA "analytics"; diff --git a/src/modules/profiles/__init__.py b/src/modules/profiles/__init__.py new file mode 100644 index 00000000..67decbfe --- /dev/null +++ b/src/modules/profiles/__init__.py @@ -0,0 +1,8 @@ +import logging + +logger = logging.getLogger(__name__) + +from .cog import ProfileCog + +async def setup(bot): + await bot.add_cog(ProfileCog(bot)) diff --git a/src/modules/profiles/cog.py b/src/modules/profiles/cog.py new file mode 100644 index 00000000..85beb3af --- /dev/null +++ b/src/modules/profiles/cog.py @@ -0,0 +1,72 @@ +import asyncio +from enum import Enum +from typing import Optional +from datetime import timedelta + +import discord +from discord.ext import commands as cmds +import twitchio +from twitchio.ext import commands + + +from data.queries import ORDER +from meta import LionCog, LionBot, CrocBot +from utils.lib import utc_now +from . import logger +from .data import ProfileData + + +class UserProfile: + def __init__(self): + ... + + +class ProfileCog(LionCog): + def __init__(self, bot: LionBot): + self.bot = bot + self.data = bot.db.load_registry(ProfileData()) + + async def cog_load(self): + await self.data.init() + + async def cog_check(self, ctx): + return True + + # Profile API + async def fetch_profile_discord(self, userid: int, create=True): + """ + Fetch or create a UserProfile from the given Discord userid. + """ + ... + + async def fetch_profile_twitch(self, userid: int, create=True): + """ + Fetch or create a UserProfile from the given Twitch userid. + """ + ... + + async def fetch_profile(self, profileid: int): + """ + Fetch a UserProfile by the given id. + """ + ... + + async def merge_profiles(self, sourceid: int, targetid: int): + """ + Merge two UserProfiles by id. + Merges the 'sourceid' into the 'targetid'. + """ + ... + + async def fetch_community_discord(self, guildid: int, create=True): + ... + + async def fetch_community_twitch(self, guildid: int, create=True): + ... + + async def fetch_community(self, communityid: int): + ... + + # ----- Profile Commands ----- + + # Link twitch profile diff --git a/src/modules/profiles/data.py b/src/modules/profiles/data.py new file mode 100644 index 00000000..f9af7e42 --- /dev/null +++ b/src/modules/profiles/data.py @@ -0,0 +1,134 @@ +from data import Registry, RowModel +from data.columns import Integer, String, Timestamp + + +class ProfileData(Registry): + class UserProfileRow(RowModel): + """ + Schema + ------ + CREATE TABLE user_profiles( + profileid SERIAL PRIMARY KEY, + nickname TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ); + """ + _tablename_ = 'user_profiles' + _cache_ = {} + + profileid = Integer(primary=True) + nickname = String() + created_at = Timestamp() + + class DiscordProfileRow(RowModel): + """ + Schema + ------ + CREATE TABLE profiles_discord( + linkid SERIAL PRIMARY KEY, + profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, + userid BIGINT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ); + CREATE UNIQUE INDEX profiles_discord_profileid ON profiles_discord (profileid); + CREATE INDEX profiles_discord_userid ON profiles_discord (userid); + """ + _tablename_ = 'profiles_discord' + _cache_ = {} + + linkid = Integer(primary=True) + profileid = Integer() + userid = Integer() + created_at = Integer() + + class TwitchProfileRow(RowModel): + """ + Schema + ------ + CREATE TABLE profiles_twitch( + linkid SERIAL PRIMARY KEY, + profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, + userid BIGINT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ); + CREATE UNIQUE INDEX profiles_twitch_profileid ON profiles_twitch (profileid); + CREATE INDEX profiles_twitch_userid ON profiles_twitch (userid); + """ + _tablename_ = 'profiles_twitch' + _cache_ = {} + + linkid = Integer(primary=True) + profileid = Integer() + userid = Integer() + created_at = Timestamp() + + class CommunityRow(RowModel): + """ + Schema + ------ + CREATE TABLE communities( + communityid SERIAL PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ); + """ + _tablename_ = 'communities' + _cache_ = {} + + communityid = Integer(primary=True) + created_at = Timestamp() + + class DiscordCommunityRow(RowModel): + """ + Schema + ------ + CREATE TABLE communities_discord( + guildid BIGINT PRIMARY KEY, + communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, + linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ); + CREATE UNIQUE INDEX communities_discord_communityid ON communities_discord (communityid); + """ + _tablename_ = 'communities_discord' + _cache_ = {} + + guildid = Integer(primary=True) + communityid = Integer() + linked_at = Timestamp() + + class TwitchCommunityRow(RowModel): + """ + Schema + ------ + CREATE TABLE communities_twitch( + channelid BIGINT PRIMARY KEY, + communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, + linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ); + CREATE UNIQUE INDEX communities_twitch_communityid ON communities_twitch (communityid); + """ + _tablename_ = 'communities_twitch' + _cache_ = {} + + channelid = Integer(primary=True) + communityid = Integer() + linked_at = Timestamp() + + class CommunityMemberRow(RowModel): + """ + Schema + ------ + CREATE TABLE community_members( + memberid SERIAL PRIMARY KEY, + communityid INTEGER NOT NULL REFERENCES communities (communityud) ON DELETE CASCADE ON UPDATE CASCADE, + profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ) + CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid); + """ + _tablename_ = 'community_members' + _cache_ = {} + + memberid = Integer(primary=True) + communityid = Integer() + profileid = Integer() + created_at = Timestamp() From 41f755795f029f086730049c71a2760a9b4852af Mon Sep 17 00:00:00 2001 From: Interitio Date: Fri, 6 Sep 2024 10:59:47 +1000 Subject: [PATCH 2/9] feat: Start twitch user auth module. --- src/twitch/__init__.py | 9 +++++++++ src/twitch/authserver.py | 7 +++++++ src/twitch/cog.py | 31 +++++++++++++++++++++++++++++++ src/twitch/data.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+) create mode 100644 src/twitch/__init__.py create mode 100644 src/twitch/authserver.py create mode 100644 src/twitch/cog.py create mode 100644 src/twitch/data.py diff --git a/src/twitch/__init__.py b/src/twitch/__init__.py new file mode 100644 index 00000000..53e33e89 --- /dev/null +++ b/src/twitch/__init__.py @@ -0,0 +1,9 @@ +import logging + +logger = logging.getLogger(__name__) + +from .cog import TwitchAuthCog + +async def setup(bot): + await bot.add_cog(TwitchAuthCog(bot)) + diff --git a/src/twitch/authserver.py b/src/twitch/authserver.py new file mode 100644 index 00000000..e87c433c --- /dev/null +++ b/src/twitch/authserver.py @@ -0,0 +1,7 @@ +""" +We want to open an aiohttp server and listen on a configured port. +When we get a request, we validate it to be 'of twitch form', +parse out the error or access token, state, etc, and then pass that information on. + +Passing on maybe done through webhook server? +""" diff --git a/src/twitch/cog.py b/src/twitch/cog.py new file mode 100644 index 00000000..5dfdc39d --- /dev/null +++ b/src/twitch/cog.py @@ -0,0 +1,31 @@ +import asyncio +from enum import Enum +from typing import Optional +from datetime import timedelta + +import discord +from discord.ext import commands as cmds +import twitchio +from twitchio.ext import commands + + +from data.queries import ORDER +from meta import LionCog, LionBot, CrocBot +from utils.lib import utc_now +from . import logger +from .data import TwitchAuthData + + +class TwitchAuthCog(LionCog): + def __init__(self, bot: LionBot): + self.bot = bot + self.data = bot.db.load_registry(TwitchAuthData()) + + async def cog_load(self): + await self.data.init() + + # ----- Auth API ----- + + async def fetch_client_for(self, userid: int): + ... + diff --git a/src/twitch/data.py b/src/twitch/data.py new file mode 100644 index 00000000..eed13589 --- /dev/null +++ b/src/twitch/data.py @@ -0,0 +1,28 @@ +from data import Registry, RowModel +from data.columns import Integer, String, Timestamp + + +class TwitchAuthData(Registry): + class UserAuthRow(RowModel): + """ + Schema + ------ + CREATE TABLE twitch_tokens( + userid BIGINT PRIMARY KEY, + access_token TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + refresh_token TEXT NOT NULL, + obtained_at TIMESTAMPTZ + ); + + """ + _tablename_ = 'twitch_tokens' + _cache_ = {} + + userid = Integer(primary=True) + access_token = String() + expires_at = Timestamp() + refresh_token = String() + obtained_at = Timestamp() + +# TODO: Scopes From 44d6d7749448cb1098c229a795dbdb11b9dd0cd4 Mon Sep 17 00:00:00 2001 From: Interitio Date: Mon, 23 Sep 2024 15:56:18 +1000 Subject: [PATCH 3/9] feat(twitch): Add authentication server. --- src/twitch/authclient.py | 50 ++++++++++++++++++++++ src/twitch/authserver.py | 91 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 135 insertions(+), 6 deletions(-) create mode 100644 src/twitch/authclient.py diff --git a/src/twitch/authclient.py b/src/twitch/authclient.py new file mode 100644 index 00000000..509b080c --- /dev/null +++ b/src/twitch/authclient.py @@ -0,0 +1,50 @@ +""" +Testing client for the twitch AuthServer. +""" +import sys +import os + +sys.path.insert(0, os.path.join(os.getcwd())) +sys.path.insert(0, os.path.join(os.getcwd(), "src")) + +import asyncio +import aiohttp +from twitchAPI.twitch import Twitch +from twitchAPI.oauth import UserAuthenticator +from twitchAPI.type import AuthScope + +from meta.config import conf + + +URI = "http://localhost:3000/twiauth/confirm" +TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ] + +async def main(): + # Load in client id and secret + twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret']) + auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI) + url = auth.return_auth_url() + + # Post url to user + print(url) + + # Send listen request to server + # Wait for listen request + async with aiohttp.ClientSession() as session: + async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws: + await ws.send_json({'state': auth.state}) + result = await ws.receive_json() + + # Hopefully get back code, print the response + print(f"Recieved: {result}") + + # Authorise with code and client details + tokens = await auth.authenticate(user_token=result['code']) + if tokens: + token, refresh = tokens + await twitch.set_user_authentication(token, TARGET_SCOPE, refresh) + print(f"Authorised!") + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/twitch/authserver.py b/src/twitch/authserver.py index e87c433c..b26c5953 100644 --- a/src/twitch/authserver.py +++ b/src/twitch/authserver.py @@ -1,7 +1,86 @@ -""" -We want to open an aiohttp server and listen on a configured port. -When we get a request, we validate it to be 'of twitch form', -parse out the error or access token, state, etc, and then pass that information on. +import logging +import uuid +import asyncio +from contextvars import ContextVar -Passing on maybe done through webhook server? -""" +import aiohttp +from aiohttp import web + +logger = logging.getLogger(__name__) +reqid: ContextVar[str] = ContextVar('reqid', default='ROOT') + + +class AuthServer: + def __init__(self): + self.listeners = {} + + async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse: + args = request.query + if 'state' not in args: + raise web.HTTPBadRequest(text="No state provided.") + if args['state'] not in self.listeners: + raise web.HTTPBadRequest(text="Invalid state.") + self.listeners[args['state']].set_result(dict(args)) + return web.Response(text="Authorisation complete! You may now close this page and return to the application.") + + async def handle_listen_request(self, request: web.Request) -> web.StreamResponse: + _reqid = str(uuid.uuid1()) + reqid.set(_reqid) + + logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}") + + ws = web.WebSocketResponse() + await ws.prepare(request) + + # Get the listen request data + try: + listen_req = await ws.receive_json(timeout=60) + logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}") + if 'state' not in listen_req: + logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.") + raise web.HTTPBadRequest(text="Listen request must include state string.") + elif listen_req['state'] in self.listeners: + logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.") + raise web.HTTPBadRequest(text="Invalid state string.") + except ValueError: + logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.") + raise web.HTTPBadRequest(text="Request must be a JSON formatted string.") + except TypeError: + logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.") + raise web.HTTPBadRequest(text="Request must be a JSON formatted string.") + except asyncio.TimeoutError: + logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.") + raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.") + except Exception: + logger.exception(f"[reqid: {_reqid}] Unknown exception.") + raise web.HTTPInternalServerError() + + try: + fut = self.listeners[listen_req['state']] = asyncio.Future() + result = await asyncio.wait_for(fut, timeout=120) + except asyncio.TimeoutError: + logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.") + raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.") + finally: + self.listeners.pop(listen_req['state'], None) + + logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.") + await ws.send_json(result) + await ws.close() + logger.debug(f"[reqid: {_reqid}] Request completed handling.") + + return ws + +def main(argv): + app = web.Application() + server = AuthServer() + app.router.add_get("/twiauth/confirm", server.handle_twitch_callback) + app.router.add_get("/twiauth/listen", server.handle_listen_request) + + logger.info("App setup and configured. Starting now.") + web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080) + + +if __name__ == '__main__': + import sys + main(sys.argv) From caa907b6d9db4759ea030c0dafa37aecbba6e1de Mon Sep 17 00:00:00 2001 From: Interitio Date: Mon, 23 Sep 2024 15:56:51 +1000 Subject: [PATCH 4/9] feat(twitch): Add UserAuthFlow for user auth. --- src/twitch/userflow.py | 84 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 src/twitch/userflow.py diff --git a/src/twitch/userflow.py b/src/twitch/userflow.py new file mode 100644 index 00000000..11b7fdfd --- /dev/null +++ b/src/twitch/userflow.py @@ -0,0 +1,84 @@ +from typing import Optional + +from aiohttp import web + +import aiohttp +from twitchAPI.twitch import Twitch +from twitchAPI.oauth import UserAuthenticator, validate_token +from twitchAPI.type import AuthType +from twitchio.client import asyncio + +from meta.errors import SafeCancellation +from .data import TwitchAuthData +from . import logger + +class UserAuthFlow: + auth: UserAuthenticator + data: TwitchAuthData + auth_ws: str + + def __init__(self, data, auth, auth_ws): + self.auth = auth + self.data = data + self.auth_ws = auth_ws + + self._setup_done = asyncio.Event() + self._comm_task: Optional[asyncio.Task] = None + + async def setup(self): + """ + Establishes websocket connection to the AuthServer, + and requests listening for the given state. + Propagates any exceptions that occur during connection setup. + """ + if self._setup_done.is_set(): + raise ValueError("UserAuthFlow is already set up.") + self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate') + await self._setup_done.wait() + if self._comm_task.done() and (exc := self._comm_task.exception()): + raise exc + + async def _communicate(self): + async with aiohttp.ClientSession() as session: + async with session.ws_connect(self.auth_ws) as ws: + await ws.send_json({'state': self.auth.state}) + self._setup_done.set() + return await ws.receive_json() + + async def run(self): + if not self._setup_done.is_set(): + raise ValueError("Cannot run UserAuthFlow before setup.") + if self._comm_task is None: + raise ValueError("UserAuthFlow running with no comm task! This should be impossible.") + + result = await self._comm_task + if result['error']: + # TODO Custom auth errors + # This is only documented to occure when the user denies the auth + raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}") + + if result['state'] != self.auth.state: + # This should never happen unless the authserver has its wires crossed somehow, + # or the connection has been tampered with. + # TODO: Consider terminating for safety in this case? Or at least refusing more auth requests. + logger.critical( + f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG." + ) + raise SafeCancellation( + "Could not complete authentication! Invalid server response." + ) + + # Now assume result has a valid code + # Exchange code for an auth token and a refresh token + # Ignore type here, authenticate returns None if a callback function has been given. + token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore + + # Fetch the associated userid and basic info + v_result = await validate_token(token) + userid = v_result['user_id'] + + # Save auth data + + async def save_auth(self, userid: str, token: str, refresh: str, scopes: list[str]): + if not self.data._conn: + raise ValueError("Provided registry must be connected.") From 9c738ecb9136cf546aec50269f4b6d86fa143a9b Mon Sep 17 00:00:00 2001 From: Interitio Date: Wed, 25 Sep 2024 02:34:10 +1000 Subject: [PATCH 5/9] feat(twitch): Add basic user authentication flow. --- data/schema.sql | 18 ++++++++++++ src/bot.py | 1 + src/twitch/cog.py | 53 ++++++++++++++++++++++++++++++++++ src/twitch/data.py | 65 +++++++++++++++++++++++++++++++++++++----- src/twitch/lib.py | 0 src/twitch/userflow.py | 16 +++++++---- 6 files changed, 140 insertions(+), 13 deletions(-) create mode 100644 src/twitch/lib.py diff --git a/data/schema.sql b/data/schema.sql index 504ae859..345a997e 100644 --- a/data/schema.sql +++ b/data/schema.sql @@ -1485,6 +1485,24 @@ CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name -- }}} +-- Twitch User Auth {{{ +CREATE TABLE twitch_user_auth( + userid TEXT PRIMARY KEY, + access_token TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + refresh_token TEXT NOT NULL, + obtained_at TIMESTAMPTZ +); + + +CREATE TABLE twitch_user_scopes( + userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE, + scope TEXT +); +CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid); + +-- }}} + -- Analytics Data {{{ CREATE SCHEMA "analytics"; diff --git a/src/bot.py b/src/bot.py index 852530f1..db0b312e 100644 --- a/src/bot.py +++ b/src/bot.py @@ -98,6 +98,7 @@ async def main(): config=conf, initial_extensions=[ 'utils', 'core', 'analytics', + 'twitch', 'modules', 'babel', 'tracking.voice', 'tracking.text', diff --git a/src/twitch/cog.py b/src/twitch/cog.py index 5dfdc39d..b3742b0a 100644 --- a/src/twitch/cog.py +++ b/src/twitch/cog.py @@ -5,18 +5,26 @@ from datetime import timedelta import discord from discord.ext import commands as cmds + +from twitchAPI.oauth import UserAuthenticator +from twitchAPI.twitch import AuthType, Twitch +from twitchAPI.type import AuthScope import twitchio from twitchio.ext import commands from data.queries import ORDER from meta import LionCog, LionBot, CrocBot +from meta.LionContext import LionContext +from twitch.userflow import UserAuthFlow from utils.lib import utc_now from . import logger from .data import TwitchAuthData class TwitchAuthCog(LionCog): + DEFAULT_SCOPES = [] + def __init__(self, bot: LionBot): self.bot = bot self.data = bot.db.load_registry(TwitchAuthData()) @@ -29,3 +37,48 @@ class TwitchAuthCog(LionCog): async def fetch_client_for(self, userid: int): ... + async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool: + """ + Checks whether the given userid is authorised. + If 'scopes' is given, will also check the user has all of the given scopes. + """ + authrow = await self.data.UserAuthRow.fetch(userid) + if authrow: + if scopes: + has_scopes = await self.data.UserAuthRow.get_scopes_for(userid) + has_auth = set(map(str, scopes)).issubset(has_scopes) + else: + has_auth = True + else: + has_auth = False + return has_auth + + async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []): + """ + Start the user authentication flow for the given userid. + Will request the given scopes along with the default ones and any existing scopes. + """ + existing_strs = await self.data.UserAuthRow.get_scopes_for(userid) + existing = map(AuthScope, existing_strs) + to_request = set(existing).union(scopes) + return await self.start_auth(to_request) + + async def start_auth(self, scopes = []): + # TODO: Work out a way to just clone the current twitch object + # Or can we otherwise build UserAuthenticator without app auth? + twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret']) + auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri']) + flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url']) + await flow.setup() + + return flow + + # ----- Commands ----- + @cmds.hybrid_command(name='auth') + async def cmd_auth(self, ctx: LionContext): + if ctx.interaction: + await ctx.interaction.response.defer(ephemeral=True) + flow = await self.start_auth() + await ctx.reply(flow.auth.return_auth_url()) + await flow.run() + await ctx.reply("Authentication Complete!") diff --git a/src/twitch/data.py b/src/twitch/data.py index eed13589..94b45c37 100644 --- a/src/twitch/data.py +++ b/src/twitch/data.py @@ -1,4 +1,6 @@ -from data import Registry, RowModel +import datetime as dt + +from data import Registry, RowModel, Table from data.columns import Integer, String, Timestamp @@ -7,22 +9,71 @@ class TwitchAuthData(Registry): """ Schema ------ - CREATE TABLE twitch_tokens( - userid BIGINT PRIMARY KEY, + CREATE TABLE twitch_user_auth( + userid TEXT PRIMARY KEY, access_token TEXT NOT NULL, expires_at TIMESTAMPTZ NOT NULL, refresh_token TEXT NOT NULL, obtained_at TIMESTAMPTZ ); - """ - _tablename_ = 'twitch_tokens' + _tablename_ = 'twitch_user_auth' _cache_ = {} userid = Integer(primary=True) access_token = String() - expires_at = Timestamp() refresh_token = String() + expires_at = Timestamp() obtained_at = Timestamp() -# TODO: Scopes + @classmethod + async def update_user_auth( + cls, userid: str, token: str, refresh: str, + expires_at: dt.datetime, obtained_at: dt.datetime, + scopes: list[str] + ): + if cls._connector is None: + raise ValueError("Attempting to use uninitialised Registry.") + async with cls._connector.connection() as conn: + cls._connector.conn = conn + async with conn.transaction(): + # Clear row for this userid + await cls.table.delete_where(userid=userid) + + # Insert new user row + row = await cls.create( + userid=userid, + access_token=token, + refresh_token=refresh, + expires_at=expires_at, + obtained_at=obtained_at + ) + # Insert new scope rows + if scopes: + await TwitchAuthData.user_scopes.insert_many( + ('userid', 'scope'), + *((userid, scope) for scope in scopes) + ) + return row + + @classmethod + async def get_scopes_for(cls, userid: str) -> list[str]: + """ + Get a list of scopes stored for the given user. + Will return an empty list if the user is not authenticated. + """ + rows = await TwitchAuthData.user_scopes.select_where(userid=userid) + + return [row.scope for row in rows] if rows else [] + + + """ + Schema + ------ + CREATE TABLE twitch_user_scopes( + userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE, + scope TEXT + ); + CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid); + """ + user_scopes = Table('twitch_token_scopes') diff --git a/src/twitch/lib.py b/src/twitch/lib.py new file mode 100644 index 00000000..e69de29b diff --git a/src/twitch/userflow.py b/src/twitch/userflow.py index 11b7fdfd..ce7d20dc 100644 --- a/src/twitch/userflow.py +++ b/src/twitch/userflow.py @@ -1,4 +1,5 @@ from typing import Optional +import datetime as dt from aiohttp import web @@ -9,6 +10,7 @@ from twitchAPI.type import AuthType from twitchio.client import asyncio from meta.errors import SafeCancellation +from utils.lib import utc_now from .data import TwitchAuthData from . import logger @@ -52,12 +54,12 @@ class UserAuthFlow: raise ValueError("UserAuthFlow running with no comm task! This should be impossible.") result = await self._comm_task - if result['error']: + if result.get('error', None): # TODO Custom auth errors # This is only documented to occure when the user denies the auth raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}") - if result['state'] != self.auth.state: + if result.get('state', None) != self.auth.state: # This should never happen unless the authserver has its wires crossed somehow, # or the connection has been tampered with. # TODO: Consider terminating for safety in this case? Or at least refusing more auth requests. @@ -76,9 +78,11 @@ class UserAuthFlow: # Fetch the associated userid and basic info v_result = await validate_token(token) userid = v_result['user_id'] + expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in']) # Save auth data - - async def save_auth(self, userid: str, token: str, refresh: str, scopes: list[str]): - if not self.data._conn: - raise ValueError("Provided registry must be connected.") + return await self.data.UserAuthRow.update_user_auth( + userid=userid, token=token, refresh=refresh, + expires_at=expiry, obtained_at=utc_now(), + scopes=[scope.value for scope in self.auth.scopes] + ) From 9d0d19d046c889f3de97fa645c6cd73231566b1c Mon Sep 17 00:00:00 2001 From: Interitio Date: Thu, 26 Sep 2024 21:22:42 +1000 Subject: [PATCH 6/9] (profiles): Start internal API. --- data/schema.sql | 9 +++--- src/modules/profiles/cog.py | 53 ++++++++++++++++++++++++++++++++++-- src/modules/profiles/data.py | 33 ++++++++++++++++++---- 3 files changed, 83 insertions(+), 12 deletions(-) diff --git a/data/schema.sql b/data/schema.sql index 535389b1..87cf829b 100644 --- a/data/schema.sql +++ b/data/schema.sql @@ -1505,12 +1505,13 @@ CREATE INDEX profiles_discord_userid ON profiles_discord (userid); CREATE TABLE profiles_twitch( linkid SERIAL PRIMARY KEY, profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, - userid BIGINT NOT NULL, + userid TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); CREATE UNIQUE INDEX profiles_twitch_profileid ON profiles_twitch (profileid); CREATE INDEX profiles_twitch_userid ON profiles_twitch (userid); + CREATE TABLE communities( communityid SERIAL PRIMARY KEY, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() @@ -1524,7 +1525,7 @@ CREATE TABLE communities_discord( CREATE UNIQUE INDEX communities_discord_communityid ON communities_discord (communityid); CREATE TABLE communities_twitch( - channelid BIGINT PRIMARY KEY, + channelid TEXT PRIMARY KEY, communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); @@ -1532,10 +1533,10 @@ CREATE UNIQUE INDEX communities_twitch_communityid ON communities_twitch (commun CREATE TABLE community_members( memberid SERIAL PRIMARY KEY, - communityid INTEGER NOT NULL REFERENCES communities (communityud) ON DELETE CASCADE ON UPDATE CASCADE, + communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -) +); CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid); -- }}} diff --git a/src/modules/profiles/cog.py b/src/modules/profiles/cog.py index 85beb3af..ed2eb756 100644 --- a/src/modules/profiles/cog.py +++ b/src/modules/profiles/cog.py @@ -7,6 +7,7 @@ import discord from discord.ext import commands as cmds import twitchio from twitchio.ext import commands +from twitchAPI.object.api import TwitchUser from data.queries import ORDER @@ -17,9 +18,53 @@ from .data import ProfileData class UserProfile: - def __init__(self): + def __init__(self, data, profile_row, *, discord_row=None, twitch_row=None): + self.data: ProfileData = data + self.profile_row: ProfileData.UserProfileRow = profile_row + + self.discord_row: Optional[ProfileData.DiscordProfileRow] = discord_row + self.twitch_row: Optional[ProfileData.TwitchProfileRow] = twitch_row + + @property + def profileid(self): + return self.profile_row.profileid + + async def attach_discord(self, user: discord.User | discord.Member): + """ + Attach a new discord user to this profile. + """ + # TODO: Attach whatever other data we want to cache here. + # Currently Lion also caches most of this data + discord_row = await self.data.DiscordProfileRow.create( + profileid=self.profileid, + userid=user.id + ) + + async def attach_twitch(self, user: TwitchUser): + """ + Attach a new Twitch user to this profile. + """ ... + @classmethod + async def fetch_profile( + cls, data: ProfileData, + *, + profile_id: Optional[int] = None, + profile_row: Optional[ProfileData.UserProfileRow] = None, + discord_row: Optional[ProfileData.DiscordProfileRow] = None, + twitch_row: Optional[ProfileData.TwitchProfileRow] = None, + ): + if not any((profile_id, profile_row, discord_row, twitch_row)): + raise ValueError("UserProfile needs an id or a data row to construct.") + if profile_id is None: + profile_id = (profile_row or discord_row or twitch_row).profileid + profile_row = profile_row or await data.UserProfileRow.fetch(profile_id) + discord_row = discord_row or await data.DiscordProfileRow.fetch_profile(profile_id) + twitch_row = twitch_row or await data.TwitchProfileRow.fetch_profile(profile_id) + + return cls(data, profile_row, discord_row=discord_row, twitch_row=twitch_row) + class ProfileCog(LionCog): def __init__(self, bot: LionBot): @@ -37,7 +82,11 @@ class ProfileCog(LionCog): """ Fetch or create a UserProfile from the given Discord userid. """ - ... + # TODO: (Extension) May be context dependent + # Current model assumes profile (one->0..n) discord + discord_row = next(await self.data.DiscordProfileRow.fetch_where(userid=userid), None) + if discord_row is None: + profile_row = await self.data.UserProfileRow.create() async def fetch_profile_twitch(self, userid: int, create=True): """ diff --git a/src/modules/profiles/data.py b/src/modules/profiles/data.py index f9af7e42..dc85eb1e 100644 --- a/src/modules/profiles/data.py +++ b/src/modules/profiles/data.py @@ -41,6 +41,12 @@ class ProfileData(Registry): userid = Integer() created_at = Integer() + @classmethod + async def fetch_profile(cls, profileid: int): + rows = await cls.fetch_where(profiled=profileid) + return next(rows, None) + + class TwitchProfileRow(RowModel): """ Schema @@ -48,7 +54,7 @@ class ProfileData(Registry): CREATE TABLE profiles_twitch( linkid SERIAL PRIMARY KEY, profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, - userid BIGINT NOT NULL, + userid TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); CREATE UNIQUE INDEX profiles_twitch_profileid ON profiles_twitch (profileid); @@ -59,9 +65,14 @@ class ProfileData(Registry): linkid = Integer(primary=True) profileid = Integer() - userid = Integer() + userid = String() created_at = Timestamp() + @classmethod + async def fetch_profile(cls, profileid: int): + rows = await cls.fetch_where(profiled=profileid) + return next(rows, None) + class CommunityRow(RowModel): """ Schema @@ -95,12 +106,17 @@ class ProfileData(Registry): communityid = Integer() linked_at = Timestamp() + @classmethod + async def fetch_community(cls, communityid: int): + rows = await cls.fetch_where(communityd=communityid) + return next(rows, None) + class TwitchCommunityRow(RowModel): """ Schema ------ CREATE TABLE communities_twitch( - channelid BIGINT PRIMARY KEY, + channelid TEXT PRIMARY KEY, communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); @@ -109,20 +125,25 @@ class ProfileData(Registry): _tablename_ = 'communities_twitch' _cache_ = {} - channelid = Integer(primary=True) + channelid = String(primary=True) communityid = Integer() linked_at = Timestamp() + @classmethod + async def fetch_community(cls, communityid: int): + rows = await cls.fetch_where(communityd=communityid) + return next(rows, None) + class CommunityMemberRow(RowModel): """ Schema ------ CREATE TABLE community_members( memberid SERIAL PRIMARY KEY, - communityid INTEGER NOT NULL REFERENCES communities (communityud) ON DELETE CASCADE ON UPDATE CASCADE, + communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() - ) + ); CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid); """ _tablename_ = 'community_members' From 63152f3475af5f25e1dade1ba20660babdef73d6 Mon Sep 17 00:00:00 2001 From: Interitio Date: Sat, 5 Oct 2024 07:50:43 +1000 Subject: [PATCH 7/9] routine: Use ssh url for voicefix submodule. --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index f3a33f00..c02a39a4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,7 +6,7 @@ url = git@github.com:Intery/CafeHelper-Skins.git [submodule "src/modules/voicefix"] path = src/modules/voicefix - url = https://github.com/Intery/StudyLion-voicefix.git + url = git@github.com:Intery/StudyLion-voicefix.git [submodule "src/modules/streamalerts"] path = src/modules/streamalerts url = https://github.com/Intery/StudyLion-streamalerts.git From 92fee23afa6ea4f2ad6f514799788fea19bc9960 Mon Sep 17 00:00:00 2001 From: Interitio Date: Sun, 6 Oct 2024 15:43:49 +1000 Subject: [PATCH 8/9] feat(profiles): Add profile base and users. --- data/schema.sql | 21 ++- src/meta/LionBot.py | 10 ++ src/meta/LionCog.py | 1 + src/modules/__init__.py | 1 + src/modules/profiles/cog.py | 245 +++++++++++++++++++++--------- src/modules/profiles/community.py | 0 src/modules/profiles/data.py | 13 +- src/modules/profiles/profile.py | 124 +++++++++++++++ src/twitch/userflow.py | 2 +- 9 files changed, 332 insertions(+), 85 deletions(-) create mode 100644 src/modules/profiles/community.py create mode 100644 src/modules/profiles/profile.py diff --git a/data/schema.sql b/data/schema.sql index 01898dfc..faf668c8 100644 --- a/data/schema.sql +++ b/data/schema.sql @@ -1498,10 +1498,19 @@ CREATE INDEX voice_role_channels on voice_roles (channelid); -- }}} -- User and Community Profiles {{{ +DROP TABLE IF EXISTS community_members; +DROP TABLE IF EXISTS communities_twitch; +DROP TABLE IF EXISTS communities_discord; +DROP TABLE IF EXISTS communities; +DROP TABLE IF EXISTS profiles_twitch; +DROP TABLE IF EXISTS profiles_discord; +DROP TABLE IF EXISTS user_profiles; + CREATE TABLE user_profiles( profileid SERIAL PRIMARY KEY, nickname TEXT, + migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); @@ -1511,8 +1520,8 @@ CREATE TABLE profiles_discord( userid BIGINT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE UNIQUE INDEX profiles_discord_profileid ON profiles_discord (profileid); -CREATE INDEX profiles_discord_userid ON profiles_discord (userid); +CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid); +CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid); CREATE TABLE profiles_twitch( linkid SERIAL PRIMARY KEY, @@ -1520,8 +1529,8 @@ CREATE TABLE profiles_twitch( userid TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE UNIQUE INDEX profiles_twitch_profileid ON profiles_twitch (profileid); -CREATE INDEX profiles_twitch_userid ON profiles_twitch (userid); +CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid); +CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid); CREATE TABLE communities( @@ -1534,14 +1543,14 @@ CREATE TABLE communities_discord( communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE UNIQUE INDEX communities_discord_communityid ON communities_discord (communityid); +CREATE INDEX communities_discord_communityid ON communities_discord (communityid); CREATE TABLE communities_twitch( channelid TEXT PRIMARY KEY, communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE UNIQUE INDEX communities_twitch_communityid ON communities_twitch (communityid); +CREATE INDEX communities_twitch_communityid ON communities_twitch (communityid); CREATE TABLE community_members( memberid SERIAL PRIMARY KEY, diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py index 19904f9f..ebfc2875 100644 --- a/src/meta/LionBot.py +++ b/src/meta/LionBot.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from meta.CrocBot import CrocBot from core.cog import CoreCog from core.config import ConfigCog + from twitch.cog import TwitchAuthCog from tracking.voice.cog import VoiceTrackerCog from tracking.text.cog import TextTrackerCog from modules.config.cog import GuildConfigCog @@ -49,6 +50,7 @@ if TYPE_CHECKING: from modules.topgg.cog import TopggCog from modules.user_config.cog import UserConfigCog from modules.video_channels.cog import VideoCog + from modules.profiles.cog import ProfileCog logger = logging.getLogger(__name__) @@ -142,6 +144,10 @@ class LionBot(Bot): # To make the type checker happy about fetching cogs by name # TODO: Move this to stubs at some point + @overload + def get_cog(self, name: Literal['ProfileCog']) -> 'ProfileCog': + ... + @overload def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog': ... @@ -154,6 +160,10 @@ class LionBot(Bot): def get_cog(self, name: Literal['VoiceTrackerCog']) -> 'VoiceTrackerCog': ... + @overload + def get_cog(self, name: Literal['TwitchAuthCog']) -> 'TwitchAuthCog': + ... + @overload def get_cog(self, name: Literal['TextTrackerCog']) -> 'TextTrackerCog': ... diff --git a/src/meta/LionCog.py b/src/meta/LionCog.py index a2b1b625..f1719a6f 100644 --- a/src/meta/LionCog.py +++ b/src/meta/LionCog.py @@ -22,6 +22,7 @@ class LionCog(Cog): cls._placeholder_groups_ = set() cls._twitch_cmds_ = {} cls._twitch_events_ = {} + cls._twitch_events_loaded_ = set() for base in reversed(cls.__mro__): for elem, value in base.__dict__.items(): diff --git a/src/modules/__init__.py b/src/modules/__init__.py index 9e6bb1fd..ddafc4dd 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -2,6 +2,7 @@ this_package = 'modules' active_discord = [ '.sysadmin', + '.profiles', '.config', '.user_config', '.skins', diff --git a/src/modules/profiles/cog.py b/src/modules/profiles/cog.py index ed2eb756..3e0c49a0 100644 --- a/src/modules/profiles/cog.py +++ b/src/modules/profiles/cog.py @@ -1,76 +1,36 @@ import asyncio from enum import Enum -from typing import Optional +from typing import Optional, overload from datetime import timedelta import discord from discord.ext import commands as cmds import twitchio from twitchio.ext import commands +from twitchio import User from twitchAPI.object.api import TwitchUser from data.queries import ORDER -from meta import LionCog, LionBot, CrocBot +from meta import LionCog, LionBot, CrocBot, LionContext +from meta.logger import log_wrap from utils.lib import utc_now from . import logger from .data import ProfileData - - -class UserProfile: - def __init__(self, data, profile_row, *, discord_row=None, twitch_row=None): - self.data: ProfileData = data - self.profile_row: ProfileData.UserProfileRow = profile_row - - self.discord_row: Optional[ProfileData.DiscordProfileRow] = discord_row - self.twitch_row: Optional[ProfileData.TwitchProfileRow] = twitch_row - - @property - def profileid(self): - return self.profile_row.profileid - - async def attach_discord(self, user: discord.User | discord.Member): - """ - Attach a new discord user to this profile. - """ - # TODO: Attach whatever other data we want to cache here. - # Currently Lion also caches most of this data - discord_row = await self.data.DiscordProfileRow.create( - profileid=self.profileid, - userid=user.id - ) - - async def attach_twitch(self, user: TwitchUser): - """ - Attach a new Twitch user to this profile. - """ - ... - - @classmethod - async def fetch_profile( - cls, data: ProfileData, - *, - profile_id: Optional[int] = None, - profile_row: Optional[ProfileData.UserProfileRow] = None, - discord_row: Optional[ProfileData.DiscordProfileRow] = None, - twitch_row: Optional[ProfileData.TwitchProfileRow] = None, - ): - if not any((profile_id, profile_row, discord_row, twitch_row)): - raise ValueError("UserProfile needs an id or a data row to construct.") - if profile_id is None: - profile_id = (profile_row or discord_row or twitch_row).profileid - profile_row = profile_row or await data.UserProfileRow.fetch(profile_id) - discord_row = discord_row or await data.DiscordProfileRow.fetch_profile(profile_id) - twitch_row = twitch_row or await data.TwitchProfileRow.fetch_profile(profile_id) - - return cls(data, profile_row, discord_row=discord_row, twitch_row=twitch_row) +from .profile import UserProfile class ProfileCog(LionCog): def __init__(self, bot: LionBot): self.bot = bot + + assert bot.crocbot is not None + self.crocbot: CrocBot = bot.crocbot self.data = bot.db.load_registry(ProfileData()) + self._profile_migrators = {} + self._comm_migrators = {} + async def cog_load(self): await self.data.init() @@ -78,34 +38,84 @@ class ProfileCog(LionCog): return True # Profile API - async def fetch_profile_discord(self, userid: int, create=True): - """ - Fetch or create a UserProfile from the given Discord userid. - """ - # TODO: (Extension) May be context dependent - # Current model assumes profile (one->0..n) discord - discord_row = next(await self.data.DiscordProfileRow.fetch_where(userid=userid), None) - if discord_row is None: - profile_row = await self.data.UserProfileRow.create() + def add_profile_migrator(self, migrator, name=None): + name = name or migrator.__name__ + self._profile_migrators[name or migrator.__name__] = migrator - async def fetch_profile_twitch(self, userid: int, create=True): - """ - Fetch or create a UserProfile from the given Twitch userid. - """ - ... + logger.info( + f"Added user profile migrator {name}: {migrator}" + ) + return migrator - async def fetch_profile(self, profileid: int): + def del_profile_migrator(self, name: str): + migrator = self._profile_migrators.pop(name, None) + + logger.info( + f"Removed user profile migrator {name}: {migrator}" + ) + + @log_wrap(action="profile migration") + async def migrate_profile(self, source_profile, target_profile) -> list[str]: + logger.info( + f"Beginning user profile migration from {source_profile!r} to {target_profile!r}" + ) + results = [] + # Wrap this in a transaction so if something goes wrong with migration, + # we roll back safely (although this may mess up caches) + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + for name, migrator in self._profile_migrators.items(): + try: + result = await migrator(source_profile, target_profile) + if result: + results.append(result) + except Exception: + logger.exception( + f"Unexpected exception running user profile migrator {name} " + f"migrating {source_profile!r} to {target_profile!r}." + ) + raise + + # Move all Discord and Twitch profile references over to the new profile + discord_rows = await self.data.DiscordProfileRow.table.update_where( + profileid=source_profile.profileid + ).set(profileid=target_profile.profileid) + results.append(f"Migrated {len(discord_rows)} attached discord account(s).") + + twitch_rows = await self.data.TwitchProfileRow.table.update_where( + profileid=source_profile.profileid + ).set(profileid=target_profile.profileid) + results.append(f"Migrated {len(twitch_rows)} attached twitch account(s).") + + # And then mark the old profile as migrated + await source_profile.update(migrate=target_profile.profileid) + results.append("Marking old profile as migrated.. finished!") + return results + + async def fetch_profile_by_id(self, profile_id: int) -> UserProfile: """ Fetch a UserProfile by the given id. """ - ... + return await UserProfile.fetch_profile(self.bot, profile_id=profile_id) - async def merge_profiles(self, sourceid: int, targetid: int): + async def fetch_profile_discord(self, user: discord.Member | discord.User) -> UserProfile: """ - Merge two UserProfiles by id. - Merges the 'sourceid' into the 'targetid'. + Fetch or create a UserProfile from the provided discord account. """ - ... + profile = await UserProfile.fetch_from_discordid(user.id) + if profile is None: + profile = await UserProfile.create_from_discord(user) + return profile + + async def fetch_profile_twitch(self, user: twitchio.User) -> UserProfile: + """ + Fetch or create a UserProfile from the provided twitch account. + """ + profile = await UserProfile.fetch_from_twitchid(user.id) + if profile is None: + profile = await UserProfile.create_from_twitch(user) + return profile async def fetch_community_discord(self, guildid: int, create=True): ... @@ -118,4 +128,95 @@ class ProfileCog(LionCog): # ----- Profile Commands ----- - # Link twitch profile + @cmds.hybrid_group( + name='profiles', + description="Base comand group for user profiles." + ) + async def profiles_grp(self, ctx: LionContext): + ... + + @profiles_grp.group( + name='link', + description="Base command group for linking profiles" + ) + async def profiles_link_grp(self, ctx: LionContext): + ... + + @profiles_link_grp.command( + name='twitch', + description="Link a twitch account to your current profile." + ) + async def profiles_link_twitch_cmd(self, ctx: LionContext): + if not ctx.interaction: + return + + await ctx.interaction.response.defer(ephemeral=True) + + # Ask the user to go through auth to get their userid + auth_cog = self.bot.get_cog('TwitchAuthCog') + flow = await auth_cog.start_auth() + message = await ctx.reply( + f"Please [click here]({flow.auth.return_auth_url()}) to link your profile " + "to Twitch." + ) + authrow = await flow.run() + await message.edit( + content="Authentication Complete! Beginning profile merge..." + ) + + results = await self.crocbot.fetch_users(ids=[authrow.userid]) + if not results: + logger.error( + f"User {authrow} obtained from Twitch authentication does not exist." + ) + await ctx.error_reply("Sorry, something went wrong. Please try again later!") + + user = results[0] + + # Retrieve author's profile if it exists + author_profile = await UserProfile.fetch_from_discordid(self.bot, ctx.author.id) + + # Check if the twitch-side user has a profile + source_profile = await UserProfile.fetch_from_twitchid(self.bot, user.id) + + if author_profile and source_profile is None: + # All we need to do is attach the twitch row + await author_profile.attach_twitch(user) + await message.edit( + content=f"Successfully added Twitch account **{user.name}**! There was no profile data to merge." + ) + elif source_profile and author_profile is None: + # Attach the discord row to the profile + await source_profile.attach_discord(ctx.author) + await message.edit( + content=f"Successfully connect to Twitch profile **{user.name}**! There was no profile data to merge." + ) + elif source_profile is None and author_profile is None: + profile = await UserProfile.create_from_discord(self.bot, ctx.author) + await profile.attach_twitch(user) + + await message.edit( + content=f"Opened a new user profile for you and linked Twitch account **{user.name}**." + ) + elif author_profile.profileid == source_profile.profileid: + await message.edit( + content=f"The Twitch account **{user.name}** is already linked to your profile!" + ) + else: + # Migrate the existing profile data to the new profiles + try: + results = await self.migrate_profile(source_profile, author_profile) + except Exception: + await ctx.error_reply( + "An issue was encountered while merging your account profiles!\n" + "Migration rolled back, no data has been lost.\n" + "The developer has been notified. Please try again later!" + ) + raise + + content = '\n'.join(( + "## Connecting Twitch account and merging profiles...", + *results, + "**Successfully linked account and merge profile data!**" + )) + await message.edit(content=content) diff --git a/src/modules/profiles/community.py b/src/modules/profiles/community.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modules/profiles/data.py b/src/modules/profiles/data.py index dc85eb1e..eed48792 100644 --- a/src/modules/profiles/data.py +++ b/src/modules/profiles/data.py @@ -10,6 +10,7 @@ class ProfileData(Registry): CREATE TABLE user_profiles( profileid SERIAL PRIMARY KEY, nickname TEXT, + migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); """ @@ -18,8 +19,10 @@ class ProfileData(Registry): profileid = Integer(primary=True) nickname = String() + migrated = Integer() created_at = Timestamp() + class DiscordProfileRow(RowModel): """ Schema @@ -30,8 +33,8 @@ class ProfileData(Registry): userid BIGINT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); - CREATE UNIQUE INDEX profiles_discord_profileid ON profiles_discord (profileid); - CREATE INDEX profiles_discord_userid ON profiles_discord (userid); + CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid); + CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid); """ _tablename_ = 'profiles_discord' _cache_ = {} @@ -57,8 +60,8 @@ class ProfileData(Registry): userid TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); - CREATE UNIQUE INDEX profiles_twitch_profileid ON profiles_twitch (profileid); - CREATE INDEX profiles_twitch_userid ON profiles_twitch (userid); + CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid); + CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid); """ _tablename_ = 'profiles_twitch' _cache_ = {} @@ -97,7 +100,6 @@ class ProfileData(Registry): communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); - CREATE UNIQUE INDEX communities_discord_communityid ON communities_discord (communityid); """ _tablename_ = 'communities_discord' _cache_ = {} @@ -120,7 +122,6 @@ class ProfileData(Registry): communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE, linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); - CREATE UNIQUE INDEX communities_twitch_communityid ON communities_twitch (communityid); """ _tablename_ = 'communities_twitch' _cache_ = {} diff --git a/src/modules/profiles/profile.py b/src/modules/profiles/profile.py new file mode 100644 index 00000000..9e1151c5 --- /dev/null +++ b/src/modules/profiles/profile.py @@ -0,0 +1,124 @@ +from typing import Optional, Self + +import discord +import twitchio + +from meta import LionBot +from utils.lib import utc_now + +from . import logger +from .data import ProfileData + + + +class UserProfile: + def __init__(self, bot: LionBot, profile_row): + self.bot = bot + self.profile_row: ProfileData.UserProfileRow = profile_row + + @property + def cog(self): + return self.bot.get_cog('ProfileCog') + + @property + def data(self) -> ProfileData: + return self.cog.data + + @property + def profileid(self): + return self.profile_row.profileid + + def __repr__(self): + return f"" + + async def attach_discord(self, user: discord.User | discord.Member): + """ + Attach a new discord user to this profile. + Assumes the discord user does not itself have a profile. + """ + discord_row = await self.data.DiscordProfileRow.create( + profileid=self.profileid, + userid=user.id + ) + logger.info( + f"Attached discord user {user!r} to profile {self!r}" + ) + return discord_row + + async def attach_twitch(self, user: twitchio.User): + """ + Attach a new Twitch user to this profile. + """ + twitch_row = await self.data.TwitchProfileRow.create( + profileid=self.profileid, + userid=str(user.id) + ) + logger.info( + f"Attached twitch user {user!r} to profile {self!r}" + ) + return twitch_row + + async def discord_accounts(self) -> list[ProfileData.DiscordProfileRow]: + """ + Fetch the Discord accounts associated to this profile. + """ + return await self.data.DiscordProfileRow.fetch_where(profileid=self.profileid) + + async def twitch_accounts(self) -> list[ProfileData.DiscordProfileRow]: + """ + Fetch the Twitch accounts associated to this profile. + """ + return await self.data.TwitchProfileRow.fetch_where(profileid=self.profileid) + + @classmethod + async def fetch(cls, bot: LionBot, profile_id: int) -> Self: + profile_row = await bot.get_cog('ProfileCog').data.UserProfileRow.fetch(profile_id) + if profile_row is None: + raise ValueError("Provided profile_id does not exist.") + return cls(bot, profile_row) + + @classmethod + async def fetch_from_twitchid(cls, bot: LionBot, userid: int | str) -> Optional[Self]: + data = bot.get_cog('ProfileCog').data + rows = await data.TwitchProfileRow.fetch_where(userid=str(userid)) + if rows: + return await cls.fetch(bot, rows[0].profileid) + + @classmethod + async def fetch_from_discordid(cls, bot: LionBot, userid: int) -> Optional[Self]: + data = bot.get_cog('ProfileCog').data + rows = await data.DiscordProfileRow.fetch_where(userid=str(userid)) + if rows: + return await cls.fetch(bot, rows[0].profileid) + + @classmethod + async def create(cls, bot: LionBot, **kwargs) -> Self: + """ + Create a new empty profile with the given initial arguments. + + Profiles should usually be created using `create_from_discord` or `create_from_twitch` + to correctly setup initial profile preferences (e.g. name, avatar). + """ + # Create a new profile + data = bot.get_cog('ProfileCog').data + profile_row = await data.UserProfileRow.create(created_at=utc_now()) + profile = await cls.fetch(bot, profile_row.profileid) + return profile + + @classmethod + async def create_from_discord(cls, bot: LionBot, user: discord.Member | discord.User, **kwargs) -> Self: + """ + Create a new profile using the given Discord user as a base. + """ + profile = await cls.create(bot, **kwargs) + await profile.attach_discord(user) + return profile + + @classmethod + async def create_from_twitch(cls, bot: LionBot, user: twitchio.User, **kwargs) -> Self: + """ + Create a new profile using the given Twitch user as a base. + """ + profile = await cls.create(bot, **kwargs) + await profile.attach_twitch(user) + return profile diff --git a/src/twitch/userflow.py b/src/twitch/userflow.py index ce7d20dc..11c0fef9 100644 --- a/src/twitch/userflow.py +++ b/src/twitch/userflow.py @@ -47,7 +47,7 @@ class UserAuthFlow: self._setup_done.set() return await ws.receive_json() - async def run(self): + async def run(self) -> TwitchAuthData.UserAuthRow: if not self._setup_done.is_set(): raise ValueError("Cannot run UserAuthFlow before setup.") if self._comm_task is None: From 72d52b6014c297552c8bd18f0ad4d56d24028685 Mon Sep 17 00:00:00 2001 From: Interitio Date: Sun, 6 Oct 2024 21:38:09 +1000 Subject: [PATCH 9/9] feat(profiles): Add community profiles. --- data/schema.sql | 1 + src/modules/profiles/cog.py | 212 +++++++++++++++++++++++++++--- src/modules/profiles/community.py | 123 +++++++++++++++++ src/modules/profiles/data.py | 2 + src/modules/profiles/profile.py | 4 +- src/twitch/data.py | 2 +- 6 files changed, 326 insertions(+), 18 deletions(-) diff --git a/data/schema.sql b/data/schema.sql index faf668c8..0ea3aff4 100644 --- a/data/schema.sql +++ b/data/schema.sql @@ -1535,6 +1535,7 @@ CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid); CREATE TABLE communities( communityid SERIAL PRIMARY KEY, + migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); diff --git a/src/modules/profiles/cog.py b/src/modules/profiles/cog.py index 3e0c49a0..08101882 100644 --- a/src/modules/profiles/cog.py +++ b/src/modules/profiles/cog.py @@ -4,7 +4,9 @@ from typing import Optional, overload from datetime import timedelta import discord +from discord import app_commands as appcmds from discord.ext import commands as cmds +from twitchAPI.type import AuthScope import twitchio from twitchio.ext import commands from twitchio import User @@ -18,6 +20,7 @@ from utils.lib import utc_now from . import logger from .data import ProfileData from .profile import UserProfile +from .community import Community class ProfileCog(LionCog): @@ -89,7 +92,7 @@ class ProfileCog(LionCog): results.append(f"Migrated {len(twitch_rows)} attached twitch account(s).") # And then mark the old profile as migrated - await source_profile.update(migrate=target_profile.profileid) + await source_profile.update(migrated=target_profile.profileid) results.append("Marking old profile as migrated.. finished!") return results @@ -97,37 +100,107 @@ class ProfileCog(LionCog): """ Fetch a UserProfile by the given id. """ - return await UserProfile.fetch_profile(self.bot, profile_id=profile_id) + return await UserProfile.fetch(self.bot, profile_id=profile_id) async def fetch_profile_discord(self, user: discord.Member | discord.User) -> UserProfile: """ Fetch or create a UserProfile from the provided discord account. """ - profile = await UserProfile.fetch_from_discordid(user.id) + profile = await UserProfile.fetch_from_discordid(self.bot, user.id) if profile is None: - profile = await UserProfile.create_from_discord(user) + profile = await UserProfile.create_from_discord(self.bot, user) return profile async def fetch_profile_twitch(self, user: twitchio.User) -> UserProfile: """ Fetch or create a UserProfile from the provided twitch account. """ - profile = await UserProfile.fetch_from_twitchid(user.id) + profile = await UserProfile.fetch_from_twitchid(self.bot, user.id) if profile is None: - profile = await UserProfile.create_from_twitch(user) + profile = await UserProfile.create_from_twitch(self.bot, user) return profile - async def fetch_community_discord(self, guildid: int, create=True): - ... + # Community API + def add_community_migrator(self, migrator, name=None): + name = name or migrator.__name__ + self._comm_migrators[name or migrator.__name__] = migrator - async def fetch_community_twitch(self, guildid: int, create=True): - ... + logger.info( + f"Added community migrator {name}: {migrator}" + ) + return migrator - async def fetch_community(self, communityid: int): - ... + def del_community_migrator(self, name: str): + migrator = self._comm_migrators.pop(name, None) + + logger.info( + f"Removed community migrator {name}: {migrator}" + ) + + @log_wrap(action="community migration") + async def migrate_community(self, source_comm, target_comm) -> list[str]: + logger.info( + f"Beginning community migration from {source_comm!r} to {target_comm!r}" + ) + results = [] + # Wrap this in a transaction so if something goes wrong with migration, + # we roll back safely (although this may mess up caches) + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + for name, migrator in self._comm_migrators.items(): + try: + result = await migrator(source_comm, target_comm) + if result: + results.append(result) + except Exception: + logger.exception( + f"Unexpected exception running community migrator {name} " + f"migrating {source_comm!r} to {target_comm!r}." + ) + raise + + # Move all Discord and Twitch community preferences over to the new profile + discord_rows = await self.data.DiscordCommunityRow.table.update_where( + profileid=source_comm.communityid + ).set(communityid=target_comm.communityid) + results.append(f"Migrated {len(discord_rows)} attached discord guilds.") + + twitch_rows = await self.data.TwitchCommunityRow.table.update_where( + communityid=source_comm.communityid + ).set(communityid=target_comm.communityid) + results.append(f"Migrated {len(twitch_rows)} attached twitch channel(s).") + + # And then mark the old community as migrated + await source_comm.update(migrated=target_comm.communityid) + results.append("Marking old community as migrated.. finished!") + return results + + async def fetch_community_by_id(self, community_id: int) -> Community: + """ + Fetch a Community by the given id. + """ + return await Community.fetch(self.bot, community_id=community_id) + + async def fetch_community_discord(self, guild: discord.Guild) -> Community: + """ + Fetch or create a Community from the provided discord guild. + """ + comm = await Community.fetch_from_discordid(self.bot, guild.id) + if comm is None: + comm = await Community.create_from_discord(self.bot, guild) + return comm + + async def fetch_community_twitch(self, user: twitchio.User) -> Community: + """ + Fetch or create a Community from the provided twitch account. + """ + community = await Community.fetch_from_twitchid(self.bot, user.id) + if community is None: + community = await Community.create_from_twitch(self.bot, user) + return community # ----- Profile Commands ----- - @cmds.hybrid_group( name='profiles', description="Base comand group for user profiles." @@ -170,6 +243,7 @@ class ProfileCog(LionCog): f"User {authrow} obtained from Twitch authentication does not exist." ) await ctx.error_reply("Sorry, something went wrong. Please try again later!") + return user = results[0] @@ -189,7 +263,7 @@ class ProfileCog(LionCog): # Attach the discord row to the profile await source_profile.attach_discord(ctx.author) await message.edit( - content=f"Successfully connect to Twitch profile **{user.name}**! There was no profile data to merge." + content=f"Successfully connected to Twitch profile **{user.name}**! There was no profile data to merge." ) elif source_profile is None and author_profile is None: profile = await UserProfile.create_from_discord(self.bot, ctx.author) @@ -217,6 +291,114 @@ class ProfileCog(LionCog): content = '\n'.join(( "## Connecting Twitch account and merging profiles...", *results, - "**Successfully linked account and merge profile data!**" + "**Successfully linked account and merged profile data!**" + )) + await message.edit(content=content) + + # ----- Community Commands ----- + @cmds.hybrid_group( + name='community', + description="Base comand group for community profiles." + ) + async def community_grp(self, ctx: LionContext): + ... + + @community_grp.group( + name='link', + description="Base command group for linking communities" + ) + async def community_link_grp(self, ctx: LionContext): + ... + + @community_link_grp.command( + name='twitch', + description="Link a twitch account to this community." + ) + @appcmds.guild_only() + @appcmds.default_permissions(manage_guild=True) + async def comm_link_twitch_cmd(self, ctx: LionContext): + if not ctx.interaction: + return + assert ctx.guild is not None + + await ctx.interaction.response.defer(ephemeral=True) + + if not ctx.author.guild_permissions.manage_guild: + await ctx.error_reply("You need the `MANAGE_GUILD` permission to link this guild to a community.") + return + + # Ask the user to go through auth to get their userid + auth_cog = self.bot.get_cog('TwitchAuthCog') + flow = await auth_cog.start_auth( + scopes=[ + AuthScope.CHAT_EDIT, + AuthScope.CHAT_READ, + AuthScope.MODERATION_READ, + AuthScope.CHANNEL_BOT, + ] + ) + message = await ctx.reply( + f"Please [click here]({flow.auth.return_auth_url()}) to link your Twitch channel to this server." + ) + authrow = await flow.run() + await message.edit( + content="Authentication Complete! Beginning community profile merge..." + ) + + results = await self.crocbot.fetch_users(ids=[authrow.userid]) + if not results: + logger.error( + f"User {authrow} obtained from Twitch authentication does not exist." + ) + await ctx.error_reply("Sorry, something went wrong. Please try again later!") + return + + user = results[0] + + # Retrieve author's profile if it exists + guild_comm = await Community.fetch_from_discordid(self.bot, ctx.guild.id) + + # Check if the twitch-side user has a profile + twitch_comm = await Community.fetch_from_twitchid(self.bot, user.id) + + if guild_comm and twitch_comm is None: + # All we need to do is attach the twitch row + await guild_comm.attach_twitch(user) + await message.edit( + content=f"Successfully linked Twitch channel **{user.name}**! There was no community data to merge." + ) + elif twitch_comm and guild_comm is None: + # Attach the discord row to the profile + await twitch_comm.attach_discord(ctx.guild) + await message.edit( + content=f"Successfully connected to Twitch channel **{user.name}**!" + ) + elif twitch_comm is None and guild_comm is None: + profile = await Community.create_from_discord(self.bot, ctx.guild) + await profile.attach_twitch(user) + + await message.edit( + content=f"Created a new community for this server and linked Twitch account **{user.name}**." + ) + elif guild_comm.communityid == twitch_comm.communityid: + await message.edit( + content=f"This server is already linked to the Twitch channel **{user.name}**!" + ) + else: + # Migrate the existing profile data to the new profiles + try: + results = await self.migrate_community(twitch_comm, guild_comm) + except Exception: + await ctx.error_reply( + "An issue was encountered while merging your community profiles!\n" + "Migration rolled back, no data has been lost.\n" + "The developer has been notified. Please try again later!" + ) + raise + + content = '\n'.join(( + "## Connecting Twitch account and merging community profiles...", + *results, + "**Successfully linked account and merged community data!**" )) await message.edit(content=content) diff --git a/src/modules/profiles/community.py b/src/modules/profiles/community.py index e69de29b..4e7844d9 100644 --- a/src/modules/profiles/community.py +++ b/src/modules/profiles/community.py @@ -0,0 +1,123 @@ +from typing import Optional, Self + +import discord +import twitchio + +from meta import LionBot +from utils.lib import utc_now + +from . import logger +from .data import ProfileData + + + +class Community: + def __init__(self, bot: LionBot, community_row): + self.bot = bot + self.row: ProfileData.CommunityRow = community_row + + @property + def cog(self): + return self.bot.get_cog('ProfileCog') + + @property + def data(self) -> ProfileData: + return self.cog.data + + @property + def communityid(self): + return self.row.communityid + + def __repr__(self): + return f"" + + async def attach_discord(self, guild: discord.Guild): + """ + Attach a new discord guild to this community. + Assumes the discord guild is not already associated to a community. + """ + discord_row = await self.data.DiscordCommunityRow.create( + communityid=self.communityid, + guildid=guild.id + ) + logger.info( + f"Attached discord guild {guild!r} to community {self!r}" + ) + return discord_row + + async def attach_twitch(self, user: twitchio.User): + """ + Attach a new Twitch user channel to this community. + """ + twitch_row = await self.data.TwitchCommunityRow.create( + communityid=self.communityid, + channelid=str(user.id) + ) + logger.info( + f"Attached twitch channel {user!r} to community {self!r}" + ) + return twitch_row + + async def discord_guilds(self) -> list[ProfileData.DiscordCommunityRow]: + """ + Fetch the Discord guild rows associated to this community. + """ + return await self.data.DiscordCommunityRow.fetch_where(communityid=self.communityid) + + async def twitch_channels(self) -> list[ProfileData.TwitchCommunityRow]: + """ + Fetch the Twitch user rows associated to this profile. + """ + return await self.data.TwitchCommunityRow.fetch_where(communityid=self.communityid) + + @classmethod + async def fetch(cls, bot: LionBot, community_id: int) -> Self: + community_row = await bot.get_cog('ProfileCog').data.CommunityRow.fetch(community_id) + if community_row is None: + raise ValueError("Provided community_id does not exist.") + return cls(bot, community_row) + + @classmethod + async def fetch_from_twitchid(cls, bot: LionBot, channelid: int | str) -> Optional[Self]: + data = bot.get_cog('ProfileCog').data + rows = await data.TwitchCommunityRow.fetch_where(channelid=str(channelid)) + if rows: + return await cls.fetch(bot, rows[0].communityid) + + @classmethod + async def fetch_from_discordid(cls, bot: LionBot, guildid: int) -> Optional[Self]: + data = bot.get_cog('ProfileCog').data + rows = await data.DiscordCommunityRow.fetch_where(guildid=guildid) + if rows: + return await cls.fetch(bot, rows[0].communityid) + + @classmethod + async def create(cls, bot: LionBot, **kwargs) -> Self: + """ + Create a new empty community with the given initial arguments. + + Communities should usually be created using `create_from_discord` or `create_from_twitch` + to correctly setup initial preferences (e.g. name, avatar). + """ + # Create a new community + data = bot.get_cog('ProfileCog').data + row = await data.CommunityRow.create(created_at=utc_now(), **kwargs) + return await cls.fetch(bot, row.communityid) + + @classmethod + async def create_from_discord(cls, bot: LionBot, guild: discord.Guild, **kwargs) -> Self: + """ + Create a new community using the given Discord guild as a base. + """ + self = await cls.create(bot, **kwargs) + await self.attach_discord(guild) + return self + + @classmethod + async def create_from_twitch(cls, bot: LionBot, user: twitchio.User, **kwargs) -> Self: + """ + Create a new profile using the given Twitch channel user as a base. + """ + self = await cls.create(bot, **kwargs) + await self.attach_twitch(user) + return self diff --git a/src/modules/profiles/data.py b/src/modules/profiles/data.py index eed48792..f3e764c8 100644 --- a/src/modules/profiles/data.py +++ b/src/modules/profiles/data.py @@ -82,6 +82,7 @@ class ProfileData(Registry): ------ CREATE TABLE communities( communityid SERIAL PRIMARY KEY, + migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); """ @@ -89,6 +90,7 @@ class ProfileData(Registry): _cache_ = {} communityid = Integer(primary=True) + migrated = Integer() created_at = Timestamp() class DiscordCommunityRow(RowModel): diff --git a/src/modules/profiles/profile.py b/src/modules/profiles/profile.py index 9e1151c5..aaf66a96 100644 --- a/src/modules/profiles/profile.py +++ b/src/modules/profiles/profile.py @@ -64,7 +64,7 @@ class UserProfile: """ return await self.data.DiscordProfileRow.fetch_where(profileid=self.profileid) - async def twitch_accounts(self) -> list[ProfileData.DiscordProfileRow]: + async def twitch_accounts(self) -> list[ProfileData.TwitchProfileRow]: """ Fetch the Twitch accounts associated to this profile. """ @@ -87,7 +87,7 @@ class UserProfile: @classmethod async def fetch_from_discordid(cls, bot: LionBot, userid: int) -> Optional[Self]: data = bot.get_cog('ProfileCog').data - rows = await data.DiscordProfileRow.fetch_where(userid=str(userid)) + rows = await data.DiscordProfileRow.fetch_where(userid=(userid)) if rows: return await cls.fetch(bot, rows[0].profileid) diff --git a/src/twitch/data.py b/src/twitch/data.py index 94b45c37..ab3459c3 100644 --- a/src/twitch/data.py +++ b/src/twitch/data.py @@ -76,4 +76,4 @@ class TwitchAuthData(Registry): ); CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid); """ - user_scopes = Table('twitch_token_scopes') + user_scopes = Table('twitch_user_scopes')