Compare commits

...

25 Commits

Author SHA1 Message Date
445eccccd6 feat(dreams): Add basic document viewer. 2025-06-14 03:26:41 +10:00
6ec500ec87 Add missing CrocBot 2025-06-13 23:57:43 +10:00
8e2bd67efc Skeleton module structure. 2025-06-13 23:50:28 +10:00
092e818990 api: Add event logging via webhook. 2025-06-13 23:47:51 +10:00
2bf95beaae (profiles): Improve logging and error handling. 2025-06-12 23:35:29 +10:00
250b55634d cleanup(api): Move route registration to routes. 2025-06-12 23:11:34 +10:00
9e5c2f5777 feat(profiles): Add profile link UI. 2025-06-12 23:10:14 +10:00
e1a1f7d4fe fix: Add CASCADE for event types on deletion. 2025-06-11 19:35:39 +10:00
c3ed48e918 fix: Properly delete UserProfile. 2025-06-11 19:33:08 +10:00
48a01a2861 fix: Typo in raid event visitor_count field. 2025-06-11 19:20:34 +10:00
d83709d2c2 fix: Remove enum contains for 3.11 compat. 2025-06-10 23:20:13 +10:00
7f977f90e8 cleanup: Remove MetaModel. 2025-06-10 22:43:08 +10:00
04b6dcbc3f fix: run_app usage 2025-06-10 22:42:51 +10:00
c07577cc0a Support setting port in config. 2025-06-10 22:27:10 +10:00
dc551b34a9 feat(api): Finished initial route collection. 2025-06-10 13:00:37 +10:00
94bc8b6c21 (api): Add document routes. 2025-06-08 22:05:13 +10:00
aba73b8bba dix(data): Fix typo in schema. 2025-06-08 22:04:42 +10:00
77dc90cc32 feat(api): Initial API server and stamps routes. 2025-06-07 05:29:00 +10:00
a02cc0977a (document): Add created ts and file format. 2025-06-07 05:27:59 +10:00
5efcdd6709 (data): Schema and object model. 2025-06-06 22:27:57 +10:00
0adccaae02 Load twitch and profile modules 2025-06-06 22:25:53 +10:00
a7afa5001d fix(data) Fix for new pgsql lib version. 2025-06-06 22:25:05 +10:00
d271248812 fix(meta): Repair issues from dpy update. 2025-06-06 22:24:28 +10:00
8421c5359d (WIP) Add user profile module. 2025-06-06 00:05:41 +10:00
2cf81c38e8 Add twitch auth module. 2025-06-06 00:05:24 +10:00
42 changed files with 5087 additions and 16 deletions

View File

@@ -189,9 +189,10 @@ CREATE TABLE stamp_types (
CREATE TABLE documents (
document_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
document_data VARCHAR NOT NULL,
document_data TEXT NOT NULL,
seal INTEGER NOT NULL,
metadata TEXT
metadata TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE document_stamps (
@@ -230,14 +231,14 @@ CREATE TABLE plain_events (
event_id integer PRIMARY KEY,
event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'),
message TEXT NOT NULL,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
);
CREATE TABLE raid_events (
event_id integer PRIMARY KEY,
event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'),
visitor_count INTEGER NOT NULL,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
);
CREATE TABLE cheer_events (
@@ -246,7 +247,7 @@ CREATE TABLE cheer_events (
amount INTEGER NOT NULL,
cheer_type TEXT,
message TEXT,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
);
CREATE TABLE subscriber_events (
@@ -255,10 +256,37 @@ CREATE TABLE subscriber_events (
subscribed_length INTEGER NOT NULL,
tier INTEGER NOT NULL,
message TEXT,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
);
CREATE VIEW event_details AS
SELECT
events.event_id AS event_id,
events.user_id AS user_id,
events.document_id AS document_id,
events.user_name AS user_name,
events.event_type AS event_type,
events.occurred_at AS occurred_at,
events.created_at AS created_at,
plain_events.message AS plain_message,
raid_events.visitor_count AS raid_visitor_count,
cheer_events.amount AS cheer_amount,
cheer_events.cheer_type AS cheer_type,
cheer_events.message AS cheer_message,
subscriber_events.subscribed_length AS subscriber_length,
subscriber_events.tier AS subscriber_tier,
subscriber_events.message AS subscriber_message,
documents.seal AS document_seal
FROM
events
LEFT JOIN plain_events USING (event_id)
LEFT JOIN raid_events USING (event_id)
LEFT JOIN cheer_events USING (event_id)
LEFT JOIN subscriber_events USING (event_id)
LEFT JOIN documents USING (document_id)
ORDER BY events.occurred_at ASC;
-- }}}
-- Specimens {{{
@@ -269,6 +297,7 @@ CREATE TABLE user_specimens (
born_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
forgotten_at TIMESTAMPTZ
);
CREATE UNIQUE INDEX ON user_specimens (owner_id) WHERE forgotten_at IS NULL;
-- }}}

View File

@@ -4,3 +4,5 @@ discord.py [voice]
iso8601
psycopg[pool]
pytz
twitchio
twitchAPI

61
src/api.py Normal file
View File

@@ -0,0 +1,61 @@
import sys
import os
import asyncio
from aiohttp import web
import logging
from meta import conf
from data import Database
from utils.auth import key_auth_factory
from datamodels import DataModel
from constants import DATA_VERSION
from modules.profiles.data import ProfileData
from routes import dbvar, datamodelsv, profiledatav, register_routes
sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
logger = logging.getLogger(__name__)
async def attach_db(app: web.Application):
db = Database(conf.data['args'])
async with db.open():
version = await db.version()
if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
logger.critical(error)
raise RuntimeError(error)
datamodel = DataModel()
db.load_registry(datamodel)
await datamodel.init()
profiledata = ProfileData()
db.load_registry(profiledata)
await profiledata.init()
app[dbvar] = db
app[datamodelsv] = datamodel
app[profiledatav] = profiledata
yield
async def test(request: web.Request) -> web.Response:
return web.Response(text="Welcome to the Dreamspace API. Please donate an important childhood memory to continue.")
def app_factory():
auth = key_auth_factory(conf.API['TOKEN'])
app = web.Application(middlewares=[auth])
app.cleanup_ctx.append(attach_db)
app.router.add_get('/', test)
register_routes(app.router)
return app
if __name__ == '__main__':
app = app_factory()
web.run_app(app, port=int(conf.API['PORT']))

View File

@@ -4,6 +4,7 @@ import logging
import aiohttp
import discord
from discord.ext import commands
from twitchAPI.twitch import Twitch
from meta import LionBot, conf, sharding, appname
from meta.app import shardname
@@ -49,13 +50,15 @@ async def _data_monitor() -> ComponentStatus:
async def main():
log_action_stack.set(("Initialising",))
logger.info("Initialising StudyLion")
logger.info("Initialising LionBot")
intents = discord.Intents.all()
intents.members = True
intents.message_content = True
intents.presences = False
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
async with db.open():
version = await db.version()
if version.version != DATA_VERSION:
@@ -73,6 +76,7 @@ async def main():
config=conf,
initial_extensions=[
'core',
'twitch',
'modules',
],
web_client=session,
@@ -82,6 +86,7 @@ async def main():
help_command=None,
proxy=conf.bot.get('proxy', None),
chunk_guilds_at_startup=False,
twitch=twitch
) as lionbot:
ctx_bot.set(lionbot)
lionbot.system_monitor.add_component(
@@ -89,11 +94,11 @@ async def main():
)
try:
log_context.set(f"APP: {appname}")
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
logger.info("LionBot initialised, starting!", extra={'action': 'Starting'})
await lionbot.start(conf.bot['TOKEN'])
except asyncio.CancelledError:
log_context.set(f"APP: {appname}")
logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
logger.info("LionBot closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
def _main():

6
src/brand.py Normal file
View File

@@ -0,0 +1,6 @@
import discord
# Theme
MAIN_COLOUR = discord.Colour.from_str('#11EA11')
ACCENT_COLOUR = discord.Colour.from_str('#EA11EA')

View File

@@ -11,6 +11,7 @@ from meta.app import shardname, appname
from meta.logger import log_wrap
from utils.lib import utc_now
from datamodels import DataModel
from .data import CoreData
logger = logging.getLogger(__name__)
@@ -29,7 +30,9 @@ class CoreCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = CoreData()
self.datamodel = DataModel()
bot.db.load_registry(self.data)
bot.db.load_registry(self.datamodel)
self.app_config: Optional[CoreData.AppConfig] = None
self.bot_config: Optional[CoreData.BotConfig] = None
@@ -43,6 +46,9 @@ class CoreCog(LionCog):
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
await self.data.init()
await self.datamodel.init()
# Load the app command cache
await self.reload_appcmd_cache()

View File

@@ -47,8 +47,8 @@ class Connector:
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=4,
max_size=8,
min_size=1,
max_size=4,
configure=self._setup_connection,
kwargs=self._conn_kwargs
)

View File

@@ -3,7 +3,7 @@ from typing import Optional
from psycopg import AsyncCursor, sql
from psycopg.abc import Query, Params
from psycopg._encodings import pgconn_encoding
from psycopg._encodings import conn_encoding
logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@ class AsyncLoggingCursor(AsyncCursor):
elif isinstance(query, (sql.SQL, sql.Composed)):
msg = query.as_string(self)
elif isinstance(query, bytes):
msg = query.decode(pgconn_encoding(self._conn.pgconn), 'replace')
msg = query.decode(conn_encoding(self._conn.pgconn), 'replace')
else:
msg = repr(query)
return msg

337
src/datamodels.py Normal file
View File

@@ -0,0 +1,337 @@
from io import BytesIO
import base64
from enum import Enum
from typing import NamedTuple
from data import Registry, RowModel, Table, RegisterEnum
from data.columns import Integer, String, Timestamp, Column
class EventType(Enum):
SUBSCRIBER = 'subscriber',
RAID = 'raid',
CHEER = 'cheer',
PLAIN = 'plain',
def info(self):
if self is EventType.SUBSCRIBER:
info = EventTypeInfo(
EventType.SUBSCRIBER,
DataModel.subscriber_events,
("tier", "subscribed_length", "message"),
("tier", "subscribed_length", "message"),
('subscriber_tier', 'subscriber_length', 'subscriber_message'),
)
elif self is EventType.RAID:
info = EventTypeInfo(
EventType.RAID,
DataModel.raid_events,
('visitor_count',),
('viewer_count',),
('raid_visitor_count',),
)
elif self is EventType.CHEER:
info = EventTypeInfo(
EventType.CHEER,
DataModel.cheer_events,
('amount', 'cheer_type', 'message'),
('amount', 'cheer_type', 'message'),
('cheer_amount', 'cheer_type', 'cheer_message'),
)
elif self is EventType.PLAIN:
info = EventTypeInfo(
EventType.PLAIN,
DataModel.plain_events,
('message',),
('message',),
('plain_message',),
)
else:
raise ValueError("Unexpected event type.")
return info
class EventTypeInfo(NamedTuple):
typ: EventType
table: Table
columns: tuple[str, ...]
params: tuple[str, ...]
detailcolumns: tuple[str, ...]
class DataModel(Registry):
_EventType = RegisterEnum(EventType, 'EventType')
class UserPreferences(RowModel):
"""
Schema
------
CREATE TABLE user_preferences (
profileid INTEGER PRIMARY KEY REFERENCES user_profiles (profileid) ON DELETE CASCADE,
twitch_name TEXT,
preferences TEXT
);
"""
_tablename_ = 'user_preferences'
_cache_ = {}
profileid = Integer(primary=True)
twitch_name = String()
preferences = String()
class Dreamer(RowModel):
"""
Schema
------
CREATE VIEW dreamers AS
SELECT
user_profiles.profileid AS user_id,
user_preferences.twitch_name AS name,
profiles_twitch.userid AS twitch_id,
user_preferences.preferences AS preferences,
user_profiles.created_at AS created_at
FROM
user_profiles
LEFT JOIN profiles_twitch USING (profileid)
LEFT JOIN user_preferences USING (profileid);
"""
_tablename_ = 'dreamers'
_readonly_ = True
user_id = Integer(primary=True)
name = String()
twitch_id = Integer()
preferences = String()
created_at = Timestamp()
class Transaction(RowModel):
"""
Schema
------
CREATE TABLE user_wallet (
transaction_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE,
amount INTEGER NOT NULL,
description TEXT NOT NULL,
reference TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'user_wallet'
_cache_ = {}
_immutable_ = True
transaction_id = Integer(primary=True)
user_id = Integer()
amount = Integer()
description = String()
reference = String()
created_at = Timestamp()
class StampType(RowModel):
"""
Schema
------
CREATE TABLE stamp_types (
stamp_type_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
stamp_type_name TEXT UNIQUE NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'stamp_types'
_cache_ = {}
stamp_type_id = Integer(primary=True)
stamp_type_name = String()
created_at = Timestamp()
class Document(RowModel):
"""
Schema
------
CREATE TABLE documents (
document_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
document_data TEXT NOT NULL,
seal INTEGER NOT NULL,
metadata TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'documents'
_cache_ = {}
document_id = Integer(primary=True)
document_data = Column()
seal = Integer()
metadata = String()
created_at = Timestamp()
def to_bytes(self):
"""
Helper method to decode the saved document data to a byte string.
This may fail if the saved string is not base64 encoded.
"""
byts = BytesIO(base64.b64decode(self.document_data))
return byts
class DocumentStamp(RowModel):
"""
Schema
------
CREATE TABLE document_stamps (
stamp_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
document_id INTEGER NOT NULL REFERENCES documents (document_id) ON DELETE CASCADE,
stamp_type INTEGER NOT NULL REFERENCES stamp_types (stamp_type_id) ON DELETE CASCADE,
position_x INTEGER NOT NULL,
position_y INTEGER NOT NULL,
rotation REAL NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'document_stamps'
_cache_ = {}
stamp_id = Integer(primary=True)
document_id = Integer()
stamp_type = Integer()
position_x = Integer()
position_y = Integer()
rotation: Column[float] = Column()
created_at = Timestamp()
class Events(RowModel):
"""
Schema
------
CREATE TABLE events (
event_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE,
document_id INTEGER REFERENCES documents (document_id) ON DELETE SET NULL,
user_name TEXT,
event_type EventType NOT NULL,
occurred_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (event_id, event_type)
);
"""
_tablename_ = 'events'
_cache_ = {}
event_id = Integer(primary=True)
user_id = Integer()
document_id = Integer()
user_name = String()
event_type: Column[EventType] = Column()
occured_at = Timestamp()
created_at = Timestamp()
plain_events = Table('plain_events')
raid_events = Table('raid_events')
cheer_events = Table('cheer_events')
subscriber_events = Table('subscriber_events')
class EventDetails(RowModel):
"""
Schema
------
CREATE TABLE plain_events (
event_id integer PRIMARY KEY,
event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'),
message TEXT NOT NULL,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
);
CREATE TABLE raid_events (
event_id integer PRIMARY KEY,
event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'),
visitor_count INTEGER NOT NULL,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
);
CREATE TABLE cheer_events (
event_id integer PRIMARY KEY,
event_type EventType NOT NULL DEFAULT 'cheer' CHECK (event_type = 'cheer'),
amount INTEGER NOT NULL,
cheer_type TEXT,
message TEXT,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
);
CREATE TABLE subscriber_events (
event_id integer PRIMARY KEY,
event_type EventType NOT NULL DEFAULT 'subscriber' CHECK (event_type = 'subscriber'),
subscribed_length INTEGER NOT NULL,
tier INTEGER NOT NULL,
message TEXT,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
);
CREATE VIEW event_details AS
SELECT
events.event_id AS event_id,
events.user_id AS user_id,
events.document_id AS document_id,
events.user_name AS user_name,
events.event_type AS event_type,
events.occurred_at AS occurred_at,
events.created_at AS created_at,
plain_events.message AS plain_message,
raid_events.visitor_count AS raid_visitor_count,
cheer_events.amount AS cheer_amount,
cheer_events.cheer_type AS cheer_type,
cheer_events.message AS cheer_message,
subscriber_events.subscribed_length AS subscriber_length,
subscriber_events.tier AS subscriber_tier,
subscriber_events.message AS subscriber_message,
documents.seal AS document_seal
FROM
events
LEFT JOIN plain_events USING (event_id)
LEFT JOIN raid_events USING (event_id)
LEFT JOIN cheer_events USING (event_id)
LEFT JOIN subscriber_events USING (event_id)
LEFT JOIN documents USING (document_id)
ORDER BY events.occurred_at ASC;
"""
_tablename_ = 'event_details'
_readonly_ = True
event_id = Integer(primary=True)
user_id = Integer()
document_id = Integer()
user_name = String()
event_type: Column[EventType] = Column()
occurred_at = Timestamp()
created_at = Timestamp()
plain_message = String()
raid_visitor_count = Integer()
cheer_amount = Integer()
cheer_type = String()
cheer_message = String()
subscriber_length = Integer()
subscriber_tier = Integer()
subscriber_message = String()
document_seal = Integer()
class Specimen(RowModel):
"""
Schema
------
CREATE TABLE user_specimens (
specimen_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
owner_id INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE,
born_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
forgotten_at TIMESTAMPTZ
);
CREATE UNIQUE INDEX ON user_specimens (owner_id) WHERE forgotten_at IS NULL;
"""
_tablename_ = 'user_specimens'
_cache_ = {}
specimen_id = Integer(primary=True)
owner_id = Integer(primary=True)
born_at = Timestamp()
forgotten_at = Timestamp()

76
src/meta/CrocBot.py Normal file
View File

@@ -0,0 +1,76 @@
from collections import defaultdict
from typing import TYPE_CHECKING
import logging
import twitchio
from twitchio.ext import commands
from twitchio.ext import pubsub
from twitchio.ext.commands.core import itertools
from data import Database
from .config import Conf
logger = logging.getLogger(__name__)
class CrocBot(commands.Bot):
def __init__(self, *args,
config: Conf,
data: Database,
**kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.data = data
self.pubsub = pubsub.PubSubPool(self)
self._member_cache = defaultdict(dict)
async def event_ready(self):
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
async def event_join(self, channel: twitchio.Channel, user: twitchio.User):
self._member_cache[channel.name][user.name] = user
async def event_message(self, message: twitchio.Message):
if message.channel and message.author:
self._member_cache[message.channel.name][message.author.name] = message.author
await self.handle_commands(message)
async def seek_user(self, userstr: str, matching=True, fuzzy=True):
if userstr.startswith('@'):
matching = False
userstr = userstr.strip('@ ')
result = None
if matching and len(userstr) >= 3:
lowered = userstr.lower()
full_matches = []
for user in itertools.chain(*(cmems.values() for cmems in self._member_cache.values())):
matchstr = user.name.lower()
print(matchstr)
if matchstr.startswith(lowered):
result = user
break
if lowered in matchstr:
full_matches.append(user)
if result is None and full_matches:
result = full_matches[0]
print(result)
if result is None:
lookup = userstr
elif result.id is None:
lookup = result.name
else:
lookup = None
if lookup:
found = await self.fetch_users(names=[lookup])
if found:
result = found[0]
# No matches found
return result

View File

@@ -9,6 +9,7 @@ from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
from aiohttp import ClientSession
from twitchAPI.twitch import Twitch
from data import Database
from utils.lib import tabulate
@@ -23,6 +24,8 @@ from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStat
if TYPE_CHECKING:
from core.cog import CoreCog
from twitch.cog import TwitchAuthCog
from modules.profiles.cog import ProfileCog
logger = logging.getLogger(__name__)
@@ -31,7 +34,9 @@ class LionBot(Bot):
def __init__(
self, *args, appname: str, shardname: str, db: Database, config: Conf,
initial_extensions: List[str], web_client: ClientSession,
testing_guilds: List[int] = [], **kwargs
twitch: Twitch,
testing_guilds: List[int] = [],
**kwargs
):
kwargs.setdefault('tree_cls', LionTree)
super().__init__(*args, **kwargs)
@@ -43,6 +48,7 @@ class LionBot(Bot):
self.shardname = shardname
# self.appdata = appdata
self.config = config
self.twitch = twitch
self.system_monitor = SystemMonitor()
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
@@ -101,6 +107,14 @@ class LionBot(Bot):
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
...
@overload
def get_cog(self, name: Literal['ProfileCog']) -> 'ProfileCog':
...
@overload
def get_cog(self, name: Literal['TwitchAuthCog']) -> 'TwitchAuthCog':
...
@overload
def get_cog(self, name: str) -> Optional[Cog]:
...
@@ -189,7 +203,7 @@ class LionBot(Bot):
# TODO: Some of these could have more user-feedback
logger.debug(f"Handling command error for {ctx}: {exception}")
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
cmd_str = ctx.command.app_command.to_dict()
cmd_str = ctx.command.app_command.to_dict(self.tree)
else:
cmd_str = str(ctx.command)
try:

View File

@@ -133,7 +133,7 @@ class LionTree(CommandTree):
return
set_logging_context(action=f"Run {command.qualified_name}")
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict(self)}")
try:
await command._invoke_with_namespace(interaction, namespace)
except AppCommandError as e:

View File

@@ -1,7 +1,9 @@
this_package = 'modules'
active = [
'.profiles',
'.sysadmin',
'.dreamspace',
]

View File

@@ -0,0 +1,8 @@
import logging
logger = logging.getLogger(__name__)
async def setup(bot):
from .cog import DreamCog
await bot.add_cog(DreamCog(bot))

View File

@@ -0,0 +1,70 @@
import asyncio
import discord
from discord import app_commands as appcmds
from discord.ext import commands as cmds
from meta import LionCog, LionBot, LionContext
from meta.logger import log_wrap
from utils.lib import utc_now
from . import logger
from .ui.docviewer import DocumentViewer
class DreamCog(LionCog):
"""
Discord-facting interface for Dreamspace Adventures
"""
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.core.datamodel
async def cog_load(self):
pass
@log_wrap(action="Dreamer migration")
async def migrate_dreamer(self, source_profile, target_profile):
"""
Called when two dreamer profiles need to merge.
For example, when a user links a second twitch profile.
:TODO-MARKER:
Most of the migration logic is simple, e.g. just update the profileid
on the old events to the new profile.
The same applies to transactions and probably to inventory items.
However, there are some subtle choices to make, such as what to do
if both the old and the new profile have an active specimen?
A profile can only have one active specimen at a time.
There is also the question of how to merge user preferences, when those exist.
"""
...
# User command: view their dreamer card, wallet inventory etc
# (Admin): View events/documents matching certain criteria
# (User): View own event cards with info?
# Let's make a demo viewer which lists their event cards and let's them open one via select?
# /documents -> Show a paged list of documents, select option displays the document in a viewer
@cmds.hybrid_command(
name='documents',
description="View your printer log!"
)
async def documents_cmd(self, ctx: LionContext):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
events = await self.data.Events.fetch_where(user_id=profile.profileid)
docids = [event.document_id for event in events if event.document_id is not None]
if not docids:
await ctx.error_reply("You don't have any documents yet!")
return
view = DocumentViewer(self.bot, ctx.interaction.user.id, filter=(self.data.Document.document_id == docids))
await view.run(ctx.interaction)
# (User): View Specimen information

View File

@@ -0,0 +1,36 @@
from datamodels import DataModel
class Event:
_typs = {}
def __init__(self, event_row: DataModel.Events, **kwargs):
self.row = event_row
def __getattribute__(self, name: str):
...
async def get_document(self):
...
async def get_user(self):
...
class Document:
def as_bytes(self):
...
async def get_stamps(self):
...
async def refresh(self):
...
class User:
...
class Stamp:
...

View File

@@ -0,0 +1,171 @@
import binascii
from itertools import chain
from typing import Optional
from dataclasses import dataclass
import asyncio
import datetime as dt
import discord
from discord.ui.select import select, Select, UserSelect
from discord.ui.button import button, Button
from discord.ui.text_input import TextInput
from discord.enums import ButtonStyle, TextStyle
from discord.components import SelectOption
from datamodels import DataModel
from meta import LionBot, conf
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
from data import ORDER, Condition
from utils.ui import MessageUI, input
from utils.lib import MessageArgs, tabulate, utc_now
from .. import logger
class DocumentViewer(MessageUI):
"""
Simple pager which displays a filtered list of Documents.
"""
block_len = 5
def __init__(self, bot: LionBot, callerid: int, filter: Condition, **kwargs):
super().__init__(callerid=callerid, **kwargs)
self.bot = bot
self.data: DataModel = bot.core.datamodel
self.filter = filter
# Paging state
self._pagen = 0
self.blocks = [[]]
@property
def page_count(self):
return len(self.blocks)
@property
def pagen(self):
self._pagen %= self.page_count
return self._pagen
@pagen.setter
def pagen(self, value):
self._pagen = value % self.page_count
@property
def current_page(self):
return self.blocks[self.pagen]
# ----- UI Components -----
# Page backwards
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
async def prev_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True)
self.pagen -= 1
await self.refresh(thinking=press)
# Jump to page
@button(label="JUMP_PLACEHOLDER", style=ButtonStyle.blurple)
async def jump_button(self, press: discord.Interaction, pressed: Button):
"""
Jump to page button.
"""
try:
interaction, value = await input(
press,
title="Jump to page",
question="Page number to jump to"
)
value = value.strip()
except asyncio.TimeoutError:
return
if not value.lstrip('- ').isdigit():
error = discord.Embed(title="Invalid page number, please try again!",
colour=discord.Colour.brand_red())
await interaction.response.send_message(embed=error, ephemeral=True)
else:
await interaction.response.defer(thinking=True)
pagen = int(value.lstrip('- '))
if value.startswith('-'):
pagen = -1 * pagen
elif pagen > 0:
pagen = pagen - 1
self.pagen = pagen
await self.refresh(thinking=interaction)
async def jump_button_refresh(self):
component = self.jump_button
component.label = f"{self.pagen + 1}/{self.page_count}"
component.disabled = (self.page_count <= 1)
# Page forwards
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
async def next_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True)
self.pagen += 1
await self.refresh(thinking=press)
# Quit
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
async def quit_button(self, press: discord.Interaction, pressed: Button):
"""
Quit the UI.
"""
await press.response.defer()
# if self.child_viewer:
# await self.child_viewer.quit()
await self.quit()
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
files = []
embeds = []
for doc in self.current_page:
try:
imagedata = doc.to_bytes()
imagedata.seek(0)
except binascii.Error:
continue
fn = f"doc-{doc.document_id}.png"
file = discord.File(imagedata, fn)
embed = discord.Embed()
embed.set_image(url=f"attachment://{fn}")
files.append(file)
embeds.append(embed)
if not embeds:
embed = discord.Embed(description="You don't have any documents yet!")
embeds.append(embed)
print(f"FILES: {files}")
return MessageArgs(files=files, embeds=embeds)
async def refresh_layout(self):
to_refresh = (
self.jump_button_refresh(),
)
await asyncio.gather(*to_refresh)
if self.page_count > 1:
page_line = (
self.prev_button,
self.jump_button,
self.quit_button,
self.next_button,
)
else:
page_line = (self.quit_button,)
self.set_layout(page_line)
async def reload(self):
docs = await self.data.Document.fetch_where(self.filter).order_by('created_at', ORDER.DESC)
blocks = [
docs[i:i+self.block_len]
for i in range(0, len(docs), self.block_len)
]
self.blocks = blocks or [[]]

View File

@@ -0,0 +1,406 @@
from itertools import chain
from typing import Optional
from dataclasses import dataclass
import asyncio
import datetime as dt
import discord
from discord.ui.select import select, Select, UserSelect
from discord.ui.button import button, Button
from discord.ui.text_input import TextInput
from discord.enums import ButtonStyle, TextStyle
from discord.components import SelectOption
from datamodels import DataModel
from meta import LionBot, conf
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
from data import ORDER, Condition
from utils.ui import MessageUI, input
from utils.lib import MessageArgs, tabulate, utc_now
from .. import logger
class EventsUI(MessageUI):
block_len = 10
def __init__(self, bot: LionBot, callerid: int, filter: Condition, **kwargs):
super().__init__(callerid=callerid, **kwargs)
self.bot = bot
self.data: DataModel = bot.core.datamodel
self.filter = Condition
# Paging state
self._pagen = 0
self.blocks = [[]]
@property
def page_count(self):
return len(self.blocks)
@property
def pagen(self):
self._pagen %= self.page_count
return self._pagen
@pagen.setter
def pagen(self, value):
self._pagen = value % self.page_count
@property
def current_page(self):
return self.blocks[self.pagen]
# ----- UI Components -----
# Page backwards
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
async def prev_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True)
self.pagen -= 1
await self.refresh(thinking=press)
# Jump to page
# Page forwards
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
async def next_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True)
self.pagen += 1
await self.refresh(thinking=press)
# Quit
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
async def quit_button(self, press: discord.Interaction, pressed: Button):
"""
Quit the UI.
"""
await press.response.defer()
# if self.child_viewer:
# await self.child_viewer.quit()
await self.quit()
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
...
async def refresh_layout(self):
...
async def reload(self):
...
class TicketListUI(MessageUI):
# Select Ticket
@select(
cls=Select,
placeholder="TICKETS_MENU_PLACEHOLDER",
min_values=1, max_values=1
)
async def tickets_menu(self, selection: discord.Interaction, selected: Select):
await selection.response.defer(thinking=True, ephemeral=True)
if selected.values:
ticketid = int(selected.values[0])
ticket = await Ticket.fetch_ticket(self.bot, ticketid)
ticketui = TicketUI(self.bot, ticket, self._callerid)
if self.child_ticket:
await self.child_ticket.quit()
self.child_ticket = ticketui
await ticketui.run(selection)
async def tickets_menu_refresh(self):
menu = self.tickets_menu
t = self.bot.translator.t
menu.placeholder = t(_p(
'ui:tickets|menu:tickets|placeholder',
"Select Ticket"
))
options = []
for ticket in self.current_page:
option = SelectOption(
label=f"Ticket #{ticket.data.guild_ticketid}",
value=str(ticket.data.ticketid)
)
options.append(option)
menu.options = options
# Backwards
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
async def prev_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True, ephemeral=True)
self.pagen -= 1
await self.refresh(thinking=press)
# Jump to page
@button(label="JUMP_PLACEHOLDER", style=ButtonStyle.blurple)
async def jump_button(self, press: discord.Interaction, pressed: Button):
"""
Jump-to-page button.
Loads a page-switch dialogue.
"""
t = self.bot.translator.t
try:
interaction, value = await input(
press,
title=t(_p(
'ui:tickets|button:jump|input:title',
"Jump to page"
)),
question=t(_p(
'ui:tickets|button:jump|input:question',
"Page number to jump to"
))
)
value = value.strip()
except asyncio.TimeoutError:
return
if not value.lstrip('- ').isdigit():
error_embed = discord.Embed(
title=t(_p(
'ui:tickets|button:jump|error:invalid_page',
"Invalid page number, please try again!"
)),
colour=discord.Colour.brand_red()
)
await interaction.response.send_message(embed=error_embed, ephemeral=True)
else:
await interaction.response.defer(thinking=True)
pagen = int(value.lstrip('- '))
if value.startswith('-'):
pagen = -1 * pagen
elif pagen > 0:
pagen = pagen - 1
self.pagen = pagen
await self.refresh(thinking=interaction)
async def jump_button_refresh(self):
component = self.jump_button
component.label = f"{self.pagen + 1}/{self.page_count}"
component.disabled = (self.page_count <= 1)
# Forward
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
async def next_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True)
self.pagen += 1
await self.refresh(thinking=press)
# Quit
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
async def quit_button(self, press: discord.Interaction, pressed: Button):
"""
Quit the UI.
"""
await press.response.defer()
if self.child_ticket:
await self.child_ticket.quit()
await self.quit()
# ----- UI Flow -----
def _format_ticket(self, ticket) -> str:
"""
Format a ticket into a single embed line.
"""
components = (
"[#{ticketid}]({link})",
"{created}",
"`{type}[{state}]`",
"<@{targetid}>",
"{content}",
)
formatstr = ' | '.join(components)
data = ticket.data
if not data.content:
content = 'No Content'
elif len(data.content) > 100:
content = data.content[:97] + '...'
else:
content = data.content
ticketstr = formatstr.format(
ticketid=data.guild_ticketid,
link=ticket.jump_url or 'https://lionbot.org',
created=discord.utils.format_dt(data.created_at, 'd'),
type=data.ticket_type.name,
state=data.ticket_state.name,
targetid=data.targetid,
content=content,
)
if data.ticket_state is TicketState.PARDONED:
ticketstr = f"~~{ticketstr}~~"
return ticketstr
async def make_message(self) -> MessageArgs:
t = self.bot.translator.t
embed = discord.Embed(
title=t(_p(
'ui:tickets|embed|title',
"Moderation Tickets in {guild}"
)).format(guild=self.guild.name),
timestamp=utc_now()
)
tickets = self.current_page
if tickets:
desc = '\n'.join(self._format_ticket(ticket) for ticket in tickets)
else:
desc = t(_p(
'ui:tickets|embed|desc:no_tickets',
"No tickets matching the given criteria!"
))
embed.description = desc
filterstr = self.filters.formatted()
if filterstr:
embed.add_field(
name=t(_p(
'ui:tickets|embed|field:filters|name',
"Filters"
)),
value=filterstr,
inline=False
)
return MessageArgs(embed=embed)
async def refresh_layout(self):
to_refresh = (
self.edit_filter_button_refresh(),
self.select_ticket_button_refresh(),
self.pardon_button_refresh(),
self.tickets_menu_refresh(),
self.filter_type_menu_refresh(),
self.filter_state_menu_refresh(),
self.filter_target_menu_refresh(),
self.jump_button_refresh(),
)
await asyncio.gather(*to_refresh)
action_line = (
self.edit_filter_button,
self.select_ticket_button,
self.pardon_button,
)
if self.page_count > 1:
page_line = (
self.prev_button,
self.jump_button,
self.quit_button,
self.next_button,
)
else:
page_line = ()
action_line = (*action_line, self.quit_button)
if self.show_filters:
menus = (
(self.filter_type_menu,),
(self.filter_state_menu,),
(self.filter_target_menu,),
)
elif self.show_tickets and self.current_page:
menus = ((self.tickets_menu,),)
else:
menus = ()
self.set_layout(
action_line,
*menus,
page_line,
)
async def reload(self):
tickets = await Ticket.fetch_tickets(
self.bot,
*self.filters.conditions(),
guildid=self.guild.id,
)
blocks = [
tickets[i:i+self.block_len]
for i in range(0, len(tickets), self.block_len)
]
self.blocks = blocks or [[]]
class TicketUI(MessageUI):
def __init__(self, bot: LionBot, ticket: Ticket, callerid: int, **kwargs):
super().__init__(callerid=callerid, **kwargs)
self.bot = bot
self.ticket = ticket
# ----- API -----
# ----- UI Components -----
# Pardon Ticket
@button(
label="PARDON_BUTTON_PLACEHOLDER",
style=ButtonStyle.red
)
async def pardon_button(self, press: discord.Interaction, pressed: Button):
t = self.bot.translator.t
modal_title = t(_p(
'ui:ticket|button:pardon|modal:reason|title',
"Pardon Moderation Ticket"
))
input_field = TextInput(
label=t(_p(
'ui:ticket|button:pardon|modal:reason|field|label',
"Why are you pardoning this ticket?"
)),
style=TextStyle.long,
min_length=0,
max_length=1024,
)
try:
interaction, reason = await input(
press, modal_title, field=input_field, timeout=300,
)
except asyncio.TimeoutError:
raise ResponseTimedOut
await interaction.response.defer(thinking=True, ephemeral=True)
await self.ticket.pardon(modid=press.user.id, reason=reason)
await self.refresh(thinking=interaction)
async def pardon_button_refresh(self):
button = self.pardon_button
t = self.bot.translator.t
button.label = t(_p(
'ui:ticket|button:pardon|label',
"Pardon"
))
button.disabled = (self.ticket.data.ticket_state is TicketState.PARDONED)
# Quit
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
async def quit_button(self, press: discord.Interaction, pressed: Button):
"""
Quit the UI.
"""
await press.response.defer()
await self.quit()
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
return await self.ticket.make_message()
async def refresh_layout(self):
await self.pardon_button_refresh()
self.set_layout(
(self.pardon_button, self.quit_button,)
)
async def reload(self):
await self.ticket.data.refresh()

View File

View File

@@ -0,0 +1,8 @@
import logging
logger = logging.getLogger(__name__)
from .cog import ProfileCog
async def setup(bot):
await bot.add_cog(ProfileCog(bot))

455
src/modules/profiles/cog.py Normal file
View File

@@ -0,0 +1,455 @@
import asyncio
from enum import Enum
from typing import Optional, overload
from datetime import timedelta
import discord
from discord import app_commands as appcmds
from discord.ext import commands as cmds
from twitchAPI.helper import first
from twitchAPI.type import AuthScope
import twitchio
from twitchAPI.object.api import TwitchUser
from data.queries import ORDER
from meta import LionCog, LionBot, LionContext
from meta.logger import log_wrap
from utils.lib import utc_now
from . import logger
from .data import ProfileData
from .profile import UserProfile
from .community import Community
from .ui import TwitchLinkStatic, TwitchLinkFlow
class ProfileCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(ProfileData())
self._profile_migrators = {}
self._comm_migrators = {}
async def cog_load(self):
await self.data.init()
self.bot.add_view(TwitchLinkStatic(timeout=None))
async def cog_check(self, ctx):
return True
# Profile API
def add_profile_migrator(self, migrator, name=None):
name = name or migrator.__name__
self._profile_migrators[name or migrator.__name__] = migrator
logger.info(
f"Added user profile migrator {name}: {migrator}"
)
return migrator
def del_profile_migrator(self, name: str):
migrator = self._profile_migrators.pop(name, None)
logger.info(
f"Removed user profile migrator {name}: {migrator}"
)
@log_wrap(action="profile migration")
async def migrate_profile(self, source_profile, target_profile) -> list[str]:
logger.info(
f"Beginning user profile migration from {source_profile!r} to {target_profile!r}"
)
results = []
# Wrap this in a transaction so if something goes wrong with migration,
# we roll back safely (although this may mess up caches)
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
for name, migrator in self._profile_migrators.items():
try:
result = await migrator(source_profile, target_profile)
if result:
results.append(result)
except Exception:
logger.exception(
f"Unexpected exception running user profile migrator {name} "
f"migrating {source_profile!r} to {target_profile!r}."
)
raise
# Move all Discord and Twitch profile references over to the new profile
discord_rows = await self.data.DiscordProfileRow.table.update_where(
profileid=source_profile.profileid
).set(profileid=target_profile.profileid)
results.append(f"Migrated {len(discord_rows)} attached discord account(s).")
twitch_rows = await self.data.TwitchProfileRow.table.update_where(
profileid=source_profile.profileid
).set(profileid=target_profile.profileid)
results.append(f"Migrated {len(twitch_rows)} attached twitch account(s).")
# And then mark the old profile as migrated
await source_profile.profile_row.update(migrated=target_profile.profileid)
results.append("Marking old profile as migrated.. finished!")
return results
async def fetch_profile_by_id(self, profile_id: int) -> UserProfile:
"""
Fetch a UserProfile by the given id.
"""
return await UserProfile.fetch(self.bot, profile_id=profile_id)
async def fetch_profile_discord(self, user: discord.Member | discord.User) -> UserProfile:
"""
Fetch or create a UserProfile from the provided discord account.
"""
profile = await UserProfile.fetch_from_discordid(self.bot, user.id)
if profile is None:
profile = await UserProfile.create_from_discord(self.bot, user)
return profile
async def fetch_profile_twitch(self, user: twitchio.User) -> UserProfile:
"""
Fetch or create a UserProfile from the provided twitch account.
"""
profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
if profile is None:
profile = await UserProfile.create_from_twitch(self.bot, user)
return profile
# Community API
def add_community_migrator(self, migrator, name=None):
name = name or migrator.__name__
self._comm_migrators[name or migrator.__name__] = migrator
logger.info(
f"Added community migrator {name}: {migrator}"
)
return migrator
def del_community_migrator(self, name: str):
migrator = self._comm_migrators.pop(name, None)
logger.info(
f"Removed community migrator {name}: {migrator}"
)
@log_wrap(action="community migration")
async def migrate_community(self, source_comm, target_comm) -> list[str]:
logger.info(
f"Beginning community migration from {source_comm!r} to {target_comm!r}"
)
results = []
# Wrap this in a transaction so if something goes wrong with migration,
# we roll back safely (although this may mess up caches)
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
for name, migrator in self._comm_migrators.items():
try:
result = await migrator(source_comm, target_comm)
if result:
results.append(result)
except Exception:
logger.exception(
f"Unexpected exception running community migrator {name} "
f"migrating {source_comm!r} to {target_comm!r}."
)
raise
# Move all Discord and Twitch community preferences over to the new profile
discord_rows = await self.data.DiscordCommunityRow.table.update_where(
profileid=source_comm.communityid
).set(communityid=target_comm.communityid)
results.append(f"Migrated {len(discord_rows)} attached discord guilds.")
twitch_rows = await self.data.TwitchCommunityRow.table.update_where(
communityid=source_comm.communityid
).set(communityid=target_comm.communityid)
results.append(f"Migrated {len(twitch_rows)} attached twitch channel(s).")
# And then mark the old community as migrated
await source_comm.update(migrated=target_comm.communityid)
results.append("Marking old community as migrated.. finished!")
return results
async def fetch_community_by_id(self, community_id: int) -> Community:
"""
Fetch a Community by the given id.
"""
return await Community.fetch(self.bot, community_id=community_id)
async def fetch_community_discord(self, guild: discord.Guild) -> Community:
"""
Fetch or create a Community from the provided discord guild.
"""
comm = await Community.fetch_from_discordid(self.bot, guild.id)
if comm is None:
comm = await Community.create_from_discord(self.bot, guild)
return comm
async def fetch_community_twitch(self, user: twitchio.User) -> Community:
"""
Fetch or create a Community from the provided twitch account.
"""
community = await Community.fetch_from_twitchid(self.bot, user.id)
if community is None:
community = await Community.create_from_twitch(self.bot, user)
return community
# ----- Admin Commands -----
@cmds.hybrid_command(
name='linkoffer',
description="Send a message with a permanent button for profile linking"
)
@appcmds.default_permissions(manage_guild=True)
async def linkoffer_cmd(self, ctx: LionContext):
view = TwitchLinkStatic(timeout=None)
await ctx.channel.send(embed=view.embed, view=view)
# ----- Profile Commands -----
@cmds.hybrid_group(
name='profiles',
description="Base comand group for user profiles."
)
async def profiles_grp(self, ctx: LionContext):
...
@profiles_grp.group(
name='link',
description="Base command group for linking profiles"
)
async def profiles_link_grp(self, ctx: LionContext):
...
@profiles_link_grp.command(
name='twitch',
description="Link a twitch account to your current profile."
)
async def profiles_link_twitch_cmd(self, ctx: LionContext):
if not ctx.interaction:
return
flowui = TwitchLinkFlow(self.bot, ctx.author, callerid=ctx.author.id)
await flowui.run(ctx.interaction)
await flowui.wait()
async def old_profiles_link_twitch_cmd(self, ctx: LionContext):
if not ctx.interaction:
return
await ctx.interaction.response.defer(ephemeral=True)
# Ask the user to go through auth to get their userid
auth_cog = self.bot.get_cog('TwitchAuthCog')
flow = await auth_cog.start_auth()
message = await ctx.reply(
f"Please [click here]({flow.auth.return_auth_url()}) to link your profile "
"to Twitch."
)
authrow = await flow.run()
await message.edit(
content="Authentication Complete! Beginning profile merge..."
)
# results = await self.crocbot.fetch_users(ids=[authrow.userid])
# if not results:
# logger.error(
# f"User {authrow} obtained from Twitch authentication does not exist."
# )
# await ctx.error_reply("Sorry, something went wrong. Please try again later!")
# return
# user = results[0]
try:
user = await first(self.bot.twitch.get_users(user_ids=[str(authrow.userid)]))
except Exception:
logger.error(
f"Looking up user {authrow} from Twitch authentication flow raised an error.",
exc_info=True
)
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
return
if user is None:
logger.error(
f"User {authrow} obtained from Twitch authentication does not exist."
)
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
return
# Retrieve author's profile if it exists
author_profile = await UserProfile.fetch_from_discordid(self.bot, ctx.author.id)
# Check if the twitch-side user has a profile
source_profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
if author_profile and source_profile is None:
# All we need to do is attach the twitch row
await author_profile.attach_twitch(user.id)
await message.edit(
content=f"Successfully added Twitch account **{user.display_name}**! There was no profile data to merge."
)
elif source_profile and author_profile is None:
# Attach the discord row to the profile
await source_profile.attach_discord(ctx.author.id)
await message.edit(
content=f"Successfully connected to Twitch profile **{user.display_name}**! There was no profile data to merge."
)
elif source_profile is None and author_profile is None:
profile = await UserProfile.create_from_discord(self.bot, ctx.author)
await profile.attach_twitch(user.id)
await message.edit(
content=f"Opened a new user profile for you and linked Twitch account **{user.display_name}**."
)
elif author_profile.profileid == source_profile.profileid:
await message.edit(
content=f"The Twitch account **{user.display_name}** is already linked to your profile!"
)
else:
# Migrate the existing profile data to the new profiles
try:
results = await self.migrate_profile(source_profile, author_profile)
except Exception:
await ctx.error_reply(
"An issue was encountered while merging your account profiles!\n"
"Migration rolled back, no data has been lost.\n"
"The developer has been notified. Please try again later!"
)
raise
content = '\n'.join((
"## Connecting Twitch account and merging profiles...",
*results,
"**Successfully linked account and merged profile data!**"
))
await message.edit(content=content)
# ----- Community Commands -----
@cmds.hybrid_group(
name='community',
description="Base comand group for community profiles."
)
async def community_grp(self, ctx: LionContext):
...
@community_grp.group(
name='link',
description="Base command group for linking communities"
)
async def community_link_grp(self, ctx: LionContext):
...
@community_link_grp.command(
name='twitch',
description="Link a twitch account to this community."
)
@appcmds.guild_only()
@appcmds.default_permissions(manage_guild=True)
async def comm_link_twitch_cmd(self, ctx: LionContext):
if not ctx.interaction:
return
assert ctx.guild is not None
await ctx.interaction.response.defer(ephemeral=True)
if not ctx.author.guild_permissions.manage_guild:
await ctx.error_reply("You need the `MANAGE_GUILD` permission to link this guild to a community.")
return
# Ask the user to go through auth to get their userid
auth_cog = self.bot.get_cog('TwitchAuthCog')
flow = await auth_cog.start_auth(
scopes=[
AuthScope.CHAT_EDIT,
AuthScope.CHAT_READ,
AuthScope.MODERATION_READ,
AuthScope.CHANNEL_BOT,
]
)
message = await ctx.reply(
f"Please [click here]({flow.auth.return_auth_url()}) to link your Twitch channel to this server."
)
authrow = await flow.run()
await message.edit(
content="Authentication Complete! Beginning community profile merge..."
)
# results = await self.crocbot.fetch_users(ids=[authrow.userid])
# if not results:
# logger.error(
# f"User {authrow} obtained from Twitch authentication does not exist."
# )
# await ctx.error_reply("Sorry, something went wrong. Please try again later!")
# return
# user = results[0]
try:
user = await first(self.bot.twitch.get_users(user_ids=[str(authrow.userid)]))
except Exception:
logger.error(
f"Looking up user {authrow} from Twitch authentication flow raised an error.",
exc_info=True
)
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
return
if user is None:
logger.error(
f"User {authrow} obtained from Twitch authentication does not exist."
)
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
return
# Retrieve author's profile if it exists
guild_comm = await Community.fetch_from_discordid(self.bot, ctx.guild.id)
# Check if the twitch-side user has a profile
twitch_comm = await Community.fetch_from_twitchid(self.bot, user.id)
if guild_comm and twitch_comm is None:
# All we need to do is attach the twitch row
await guild_comm.attach_twitch(user.id)
await message.edit(
content=f"Successfully linked Twitch channel **{user.display_name}**! There was no community data to merge."
)
elif twitch_comm and guild_comm is None:
# Attach the discord row to the profile
await twitch_comm.attach_discord(ctx.guild.id)
await message.edit(
content=f"Successfully connected to Twitch channel **{user.display_name}**!"
)
elif twitch_comm is None and guild_comm is None:
profile = await Community.create_from_discord(self.bot, ctx.guild)
await profile.attach_twitch(user.id)
await message.edit(
content=f"Created a new community for this server and linked Twitch account **{user.display_name}**."
)
elif guild_comm.communityid == twitch_comm.communityid:
await message.edit(
content=f"This server is already linked to the Twitch channel **{user.display_name}**!"
)
else:
# Migrate the existing profile data to the new profiles
try:
results = await self.migrate_community(twitch_comm, guild_comm)
except Exception:
await ctx.error_reply(
"An issue was encountered while merging your community profiles!\n"
"Migration rolled back, no data has been lost.\n"
"The developer has been notified. Please try again later!"
)
raise
content = '\n'.join((
"## Connecting Twitch account and merging community profiles...",
*results,
"**Successfully linked account and merged community data!**"
))
await message.edit(content=content)

View File

@@ -0,0 +1,123 @@
from typing import Optional, Self
import discord
from meta import LionBot
from utils.lib import utc_now
from . import logger
from .data import ProfileData
class Community:
def __init__(self, bot: LionBot, community_row):
self.bot = bot
self.row: ProfileData.CommunityRow = community_row
@property
def cog(self):
return self.bot.get_cog('ProfileCog')
@property
def data(self) -> ProfileData:
return self.cog.data
@property
def communityid(self):
return self.row.communityid
def __repr__(self):
return f"<Community communityid={self.communityid} row={self.row}>"
async def attach_discord(self, guildid: int):
"""
Attach a new discord guild to this community.
Assumes the discord guild is not already associated to a community.
"""
discord_row = await self.data.DiscordCommunityRow.create(
communityid=self.communityid,
guildid=guildid
)
logger.info(
f"Attached discord guild {guildid} to community {self!r}"
)
return discord_row
async def attach_twitch(self, channelid: str):
"""
Attach a new Twitch user channel to this community.
"""
twitch_row = await self.data.TwitchCommunityRow.create(
communityid=self.communityid,
channelid=str(channelid)
)
logger.info(
f"Attached twitch channel {channelid} to community {self!r}"
)
return twitch_row
async def discord_guilds(self) -> list[ProfileData.DiscordCommunityRow]:
"""
Fetch the Discord guild rows associated to this community.
"""
return await self.data.DiscordCommunityRow.fetch_where(communityid=self.communityid)
async def twitch_channels(self) -> list[ProfileData.TwitchCommunityRow]:
"""
Fetch the Twitch user rows associated to this profile.
"""
return await self.data.TwitchCommunityRow.fetch_where(communityid=self.communityid)
@classmethod
async def fetch(cls, bot: LionBot, community_id: int) -> Self:
community_row = await bot.get_cog('ProfileCog').data.CommunityRow.fetch(community_id)
if community_row is None:
raise ValueError("Provided community_id does not exist.")
return cls(bot, community_row)
@classmethod
async def fetch_from_twitchid(cls, bot: LionBot, channelid: int | str) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.TwitchCommunityRow.fetch_where(channelid=str(channelid))
if rows:
return await cls.fetch(bot, rows[0].communityid)
@classmethod
async def fetch_from_discordid(cls, bot: LionBot, guildid: int) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.DiscordCommunityRow.fetch_where(guildid=guildid)
if rows:
return await cls.fetch(bot, rows[0].communityid)
@classmethod
async def create(cls, bot: LionBot, **kwargs) -> Self:
"""
Create a new empty community with the given initial arguments.
Communities should usually be created using `create_from_discord` or `create_from_twitch`
to correctly setup initial preferences (e.g. name, avatar).
"""
# Create a new community
data = bot.get_cog('ProfileCog').data
row = await data.CommunityRow.create(created_at=utc_now(), **kwargs)
return await cls.fetch(bot, row.communityid)
@classmethod
async def create_from_discord(cls, bot: LionBot, guild: discord.Guild, **kwargs) -> Self:
"""
Create a new community using the given Discord guild as a base.
"""
self = await cls.create(bot, **kwargs)
await self.attach_discord(guild.id)
return self
@classmethod
async def create_from_twitch(cls, bot: LionBot, user, **kwargs) -> Self:
"""
Create a new profile using the given Twitch channel user as a base.
The provided `user` must have an `id` attribute.
"""
self = await cls.create(bot, **kwargs)
await self.attach_twitch(str(user.id))
return self

View File

@@ -0,0 +1,158 @@
from data import Registry, RowModel
from data.columns import Integer, String, Timestamp
class ProfileData(Registry):
class UserProfileRow(RowModel):
"""
Schema
------
CREATE TABLE user_profiles(
profileid SERIAL PRIMARY KEY,
nickname TEXT,
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'user_profiles'
_cache_ = {}
profileid = Integer(primary=True)
nickname = String()
migrated = Integer()
created_at = Timestamp()
class DiscordProfileRow(RowModel):
"""
Schema
------
CREATE TABLE profiles_discord(
linkid SERIAL PRIMARY KEY,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
userid BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid);
CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid);
"""
_tablename_ = 'profiles_discord'
_cache_ = {}
linkid = Integer(primary=True)
profileid = Integer()
userid = Integer()
created_at = Integer()
@classmethod
async def fetch_profile(cls, profileid: int):
rows = await cls.fetch_where(profiled=profileid)
return next(rows, None)
class TwitchProfileRow(RowModel):
"""
Schema
------
CREATE TABLE profiles_twitch(
linkid SERIAL PRIMARY KEY,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
userid TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid);
CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid);
"""
_tablename_ = 'profiles_twitch'
_cache_ = {}
linkid = Integer(primary=True)
profileid = Integer()
userid = String()
created_at = Timestamp()
@classmethod
async def fetch_profile(cls, profileid: int):
rows = await cls.fetch_where(profiled=profileid)
return next(rows, None)
class CommunityRow(RowModel):
"""
Schema
------
CREATE TABLE communities(
communityid SERIAL PRIMARY KEY,
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities'
_cache_ = {}
communityid = Integer(primary=True)
migrated = Integer()
created_at = Timestamp()
class DiscordCommunityRow(RowModel):
"""
Schema
------
CREATE TABLE communities_discord(
guildid BIGINT PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities_discord'
_cache_ = {}
guildid = Integer(primary=True)
communityid = Integer()
linked_at = Timestamp()
@classmethod
async def fetch_community(cls, communityid: int):
rows = await cls.fetch_where(communityd=communityid)
return next(rows, None)
class TwitchCommunityRow(RowModel):
"""
Schema
------
CREATE TABLE communities_twitch(
channelid TEXT PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities_twitch'
_cache_ = {}
channelid = String(primary=True)
communityid = Integer()
linked_at = Timestamp()
@classmethod
async def fetch_community(cls, communityid: int):
rows = await cls.fetch_where(communityd=communityid)
return next(rows, None)
class CommunityMemberRow(RowModel):
"""
Schema
------
CREATE TABLE community_members(
memberid SERIAL PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
"""
_tablename_ = 'community_members'
_cache_ = {}
memberid = Integer(primary=True)
communityid = Integer()
profileid = Integer()
created_at = Timestamp()

View File

@@ -0,0 +1,138 @@
from typing import Optional, Self
import discord
from meta import LionBot
from utils.lib import utc_now
from . import logger
from .data import ProfileData
class UserProfile:
def __init__(self, bot: LionBot, profile_row):
self.bot = bot
self.profile_row: ProfileData.UserProfileRow = profile_row
@property
def cog(self):
return self.bot.get_cog('ProfileCog')
@property
def data(self) -> ProfileData:
return self.cog.data
@property
def profileid(self):
return self.profile_row.profileid
def __repr__(self):
return f"<UserProfile profileid={self.profileid} profile={self.profile_row}>"
async def get_name(self) -> Optional[str]:
return self.profile_row.nickname
async def attach_discord(self, userid: int):
"""
Attach a new discord user to this profile.
Assumes the discord user does not itself have a profile.
"""
discord_row = await self.data.DiscordProfileRow.create(
profileid=self.profileid,
userid=userid
)
logger.info(
f"Attached discord user {userid} to profile {self!r}"
)
return discord_row
async def attach_twitch(self, userid: str):
"""
Attach a new Twitch user to this profile.
"""
twitch_row = await self.data.TwitchProfileRow.create(
profileid=self.profileid,
userid=userid
)
logger.info(
f"Attached twitch user {userid} to profile {self!r}"
)
return twitch_row
async def discord_accounts(self) -> list[ProfileData.DiscordProfileRow]:
"""
Fetch the Discord accounts associated to this profile.
"""
return await self.data.DiscordProfileRow.fetch_where(
profileid=self.profileid
).order_by(
'created_at'
)
async def twitch_accounts(self) -> list[ProfileData.TwitchProfileRow]:
"""
Fetch the Twitch accounts associated to this profile.
"""
return await self.data.TwitchProfileRow.fetch_where(
profileid=self.profileid
).order_by(
'created_at'
)
@classmethod
async def fetch(cls, bot: LionBot, profile_id: int) -> Self:
profile_row = await bot.get_cog('ProfileCog').data.UserProfileRow.fetch(profile_id)
if profile_row is None:
raise ValueError("Provided profile_id does not exist.")
return cls(bot, profile_row)
@classmethod
async def fetch_from_twitchid(cls, bot: LionBot, userid: int | str) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.TwitchProfileRow.fetch_where(userid=str(userid))
if rows:
return await cls.fetch(bot, rows[0].profileid)
@classmethod
async def fetch_from_discordid(cls, bot: LionBot, userid: int) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.DiscordProfileRow.fetch_where(userid=(userid))
if rows:
return await cls.fetch(bot, rows[0].profileid)
@classmethod
async def create(cls, bot: LionBot, **kwargs) -> Self:
"""
Create a new empty profile with the given initial arguments.
Profiles should usually be created using `create_from_discord` or `create_from_twitch`
to correctly setup initial profile preferences (e.g. name, avatar).
"""
# Create a new profile
data = bot.get_cog('ProfileCog').data
profile_row = await data.UserProfileRow.create(created_at=utc_now())
profile = await cls.fetch(bot, profile_row.profileid)
return profile
@classmethod
async def create_from_discord(cls, bot: LionBot, user: discord.Member | discord.User, **kwargs) -> Self:
"""
Create a new profile using the given Discord user as a base.
"""
kwargs.setdefault('nickname', user.name)
profile = await cls.create(bot, **kwargs)
await profile.attach_discord(user.id)
return profile
@classmethod
async def create_from_twitch(cls, bot: LionBot, user, **kwargs) -> Self:
"""
Create a new profile using the given Twitch user as a base.
Assumes the provided `user` has `id` and `name` attributes.
"""
kwargs.setdefault('nickname', user.name)
profile = await cls.create(bot, **kwargs)
await profile.attach_twitch(str(user.id))
return profile

View File

@@ -0,0 +1 @@
from .twitchlink import TwitchLinkStatic, TwitchLinkFlow

View File

@@ -0,0 +1,337 @@
"""
UI Views for Twitch linkage.
- Persistent view with interaction-button to enter link flow.
- We don't store the view, but we listen to interaction button id.
- Command to enter link flow.
For link flow, send ephemeral embed with instructions and what to expect, with link button below.
After auth is granted through OAuth flow (or if not granted, e.g. on timeout or failure)
edit the embed to reflect auth situation.
If migration occurred, add the migration text as a field to the embed.
"""
import asyncio
from datetime import timedelta
from enum import IntEnum
from typing import Optional
import aiohttp
import discord
from discord.ui.button import button, Button
from discord.enums import ButtonStyle
from twitchAPI.helper import first
from meta import LionBot
from meta.errors import SafeCancellation
from meta.logger import log_wrap
from utils.ui import MessageUI
from utils.lib import MessageArgs, utc_now
from utils.ui.leo import LeoUI
from modules.profiles.profile import UserProfile
import brand
from .. import logger
class TwitchLinkStatic(LeoUI):
"""
Static UI whose only job is to display a persistent button
to ask people to connect their twitch account.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._embed: Optional[discord.Embed] = None
print("INITIALISATION")
async def interaction_check(self, interaction: discord.Interaction):
return True
@property
def embed(self) -> discord.Embed:
"""
This is the persistent message people will see with the button that starts the Oauth flow.
Not sure what this should actually say, or whether it should be customisable via command.
:TODO-MARKER:
"""
embed = discord.Embed(
title="Link your Twitch account!",
description=(
"To participate in the Dreamspace Adventure Game :TM:, "
"please start by pressing the button below to begin the login flow with Twitch!"
),
colour=brand.ACCENT_COLOUR,
)
return embed
@embed.setter
def embed(self, value):
self._embed = value
@button(label="Connect", custom_id="BTN-LINK-TWITCH", style=ButtonStyle.green, emoji='🔗')
@log_wrap(action="link-twitch-btn")
async def button_linker(self, interaction: discord.Interaction, btn: Button):
# Here we just reply to the interaction with the AuthFlow UI
# TODO
print("RESPONDING")
flowui = TwitchLinkFlow(interaction.client, interaction.user, callerid=interaction.user.id)
await flowui.run(interaction)
await flowui.wait()
class FlowState(IntEnum):
SETUP = -1
WAITING = 0
CANCELLED = 1
TIMEOUT = 2
ERRORED = 3
WORKING = 9
DONE = 10
class TwitchLinkFlow(MessageUI):
def __init__(self, bot: LionBot, caller: discord.User | discord.Member, *args, **kwargs):
kwargs.setdefault('callerid', caller.id)
super().__init__(*args, **kwargs)
self.bot = bot
self._auth_task = None
self._stage: FlowState = FlowState.SETUP
self.flow = None
self.authrow = None
self.user = caller
self._info = None
self._migration_details = None
# ----- UI API -----
async def run(self, interaction: discord.Interaction, **kwargs):
await interaction.response.defer(ephemeral=True, thinking=True)
await self._start_flow()
await self.draw(interaction, **kwargs)
if self._stage is FlowState.ERRORED:
# This can happen if starting the flow failed
await self.close()
@log_wrap(action="start-twitch-flow-ui")
async def _start_flow(self):
logger.info(f"Starting twitch authentication flow for {self.user}")
try:
self.flow = await self.bot.get_cog('TwitchAuthCog').start_auth()
except aiohttp.ClientError:
self._stage = FlowState.ERRORED
self._info = (
"Could not establish a connection to the authentication server! "
"Please try again later~"
)
logger.exception("Unexpected exception while starting authentication flow!", exc_info=True)
else:
self._stage = FlowState.WAITING
self._auth_task = asyncio.create_task(self._auth_flow())
@log_wrap(action="run-twitch-flow-ui")
async def _auth_flow(self):
"""
Run the flow and wait for a timeout, cancellation, or callback.
Update the message accordingly.
"""
assert self.flow is not None
try:
# TODO: Cancel this in cleanup
authrow = await asyncio.wait_for(self.flow.run(), timeout=60)
except asyncio.TimeoutError:
self._stage = FlowState.TIMEOUT
# Link Timed Out!
self._info = (
"We didn't receive a response so we closed the uplink "
"to keep your account safe! If you still want to connect, please try again!"
)
await self.refresh()
await self.close()
except asyncio.CancelledError:
# Presumably the user exited or the bot is shutting down.
# Not safe to edit the message, but try and cleanup
await self.close()
except SafeCancellation as e:
logger.info("User or server cancelled authentication flow: ", exc_info=True)
# Uplink Cancelled!
self._info = (
f"We couldn't complete the uplink!\nReason:*{e.msg}*"
)
self._stage = FlowState.CANCELLED
await self.refresh()
await self.close()
except Exception:
logger.exception("Something unexpected went wrong while running the flow!")
else:
self._stage = FlowState.WORKING
self._info = (
"Authentication complete! Connecting your Dreamspace account ...."
)
self.authrow = authrow
await self.refresh()
await self._link_twitch(str(authrow.userid))
await self.refresh()
await self.close()
async def cleanup(self):
await super().cleanup()
if self._auth_task and not self._auth_task.cancelled():
self._auth_task.cancel()
async def _link_twitch(self, twitch_id: str):
"""
Link the caller's profile to the given twitch_id.
Performs migration if needed.
"""
try:
twitch_user = await first(self.bot.twitch.get_users(user_ids=[twitch_id]))
except Exception:
logger.exception(
f"Looking up user {self.authrow} from Twitch authentication flow raised an error."
)
self._stage = FlowState.ERRORED
self._info = "Failed to look up your user details from Twitch! Please try again later."
return
if twitch_user is None:
logger.error(
f"User {self.authrow} obtained from Twitch authentication does not exist."
)
self._stage = FlowState.ERRORED
self._info = "Authentication failed! Please try again later."
return
profiles = self.bot.get_cog('ProfileCog')
userid = self.user.id
caller_profile = await UserProfile.fetch_from_discordid(self.bot, userid)
twitch_profile = await UserProfile.fetch_from_twitchid(self.bot, twitch_id)
succ_info = (
f"Successfully established uplink to your Twitch account **{twitch_user.display_name}** "
"and transferred dreamspace data! Happy adventuring, and watch out for the grue~"
)
# ::TODO-MARKER::
if twitch_profile is None:
if caller_profile is None:
# Neither profile exists
profile = await UserProfile.create_from_discord(self.bot, self.user)
await profile.attach_twitch(twitch_id)
self._stage = FlowState.DONE
self._info = succ_info
else:
await caller_profile.attach_twitch(twitch_id)
self._stage = FlowState.DONE
self._info = succ_info
else:
if caller_profile is None:
await twitch_profile.attach_discord(self.user.id)
self._stage = FlowState.DONE
self._info = succ_info
elif twitch_profile.profileid == caller_profile.profileid:
self._stage = FlowState.CANCELLED
self._info = (
f"The Twitch account **{twitch_user.display_name}** is already linked to your profile!"
)
else:
# In this case we have conflicting profiles we need to migrate
try:
results = await profiles.migrate_profile(twitch_profile, caller_profile)
except Exception:
self._stage = FlowState.ERRORED
self._info = (
"An issue was encountered while merging your account profiles! "
"The migration was rolled back, and not data has been lost.\n"
"The developer has been notified, please try again later!"
)
logger.exception(f"Failed to migrate profiles {twitch_profile=} to {caller_profile=}")
else:
self._stage = FlowState.DONE
self._info = succ_info
self._migration_details = '\n'.join(results)
logger.info(
f"Migrated {twitch_profile=} to {caller_profile}. Info: {self._migration_details}"
)
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
if self._stage is FlowState.SETUP:
raise ValueError("Making message before flow initialisation!")
assert self.flow is not None
if self._stage is FlowState.WAITING:
# Message should be the initial request page
dur = discord.utils.format_dt(utc_now() + timedelta(seconds=60), style='R')
title = "Press the button to login!"
desc = (
"We have generated a custom secure link for you to connect your Twitch profile! "
"Press the button below and accept the connection in your browser, "
"and we will begin the transfer!\n"
f"(Note: The link expires {dur})"
)
colour = brand.ACCENT_COLOUR
elif self._stage is FlowState.CANCELLED:
# Show cancellation message
# Show 'you can close this'
title = "Uplink Cancelled!"
desc = self._info
colour = discord.Colour.brand_red()
elif self._stage is FlowState.TIMEOUT:
title = "Link Timed Out"
desc = self._info
colour = discord.Colour.brand_red()
elif self._stage is FlowState.ERRORED:
title = "Something went wrong!"
desc = self._info
colour = discord.Colour.brand_red()
elif self._stage is FlowState.WORKING:
# We've received the auth, we are now doing migration
title = "Establishing Connection"
desc = self._info
colour = brand.ACCENT_COLOUR
elif self._stage is FlowState.DONE:
title = "Success!"
desc = self._info
colour = discord.Colour.brand_green()
else:
raise ValueError(f"Invalid stage value {self._stage}")
embed = discord.Embed(title=title, description=desc, colour=colour, timestamp=utc_now())
if self._migration_details:
embed.add_field(
name="Profile migration details",
value=self._migration_details
)
return MessageArgs(embed=embed)
async def refresh_layout(self):
# If we haven't received the auth callback yet, make the flow link button
if self.flow is None:
raise ValueError("Refreshing before flow initialisation!")
if self._stage <= FlowState.WAITING:
flow_link = self.flow.auth.return_auth_url()
button = Button(
style=ButtonStyle.link,
url=flow_link,
label="Login With Twitch"
)
self.set_layout((button,))
else:
self.set_layout(())
async def reload(self):
pass

16
src/routes/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
from .stamps import routes as stamp_routes
from .documents import routes as doc_routes
from .users import routes as user_routes
from .specimens import routes as spec_routes
from .transactions import routes as txn_routes
from .events import routes as event_routes
from .lib import dbvar, datamodelsv, profiledatav
def register_routes(router):
router.add_routes(stamp_routes)
router.add_routes(doc_routes)
router.add_routes(user_routes)
router.add_routes(spec_routes)
router.add_routes(event_routes)
router.add_routes(txn_routes)

370
src/routes/documents.py Normal file
View File

@@ -0,0 +1,370 @@
"""
- `/documents` with `POST, GET`
- `/documents/{document_id}` with `GET`, `PATCH`, `DELETE`
- `/documents/{document_id}/stamps` which is passed to `/stamps` with `document_id` set.
"""
import logging
import binascii
from datetime import datetime
from typing import Any, NamedTuple, Optional, Self, TypedDict, Unpack, reveal_type, List
from aiohttp import web
import discord
from data import Condition, condition
from data.queries import JOINTYPE
from datamodels import DataModel
from utils.lib import MessageArgs, tabulate
from .lib import ModelField, datamodelsv, event_log
from .stamps import Stamp, StampCreateParams, StampEditParams, StampPayload
routes = web.RouteTableDef()
logger = logging.getLogger(__name__)
class DocPayload(TypedDict):
document_id: int
document_data: str
seal: int
created_at: str
metadata: Optional[str]
stamps: List[StampPayload]
class DocCreateParamsReq(TypedDict, total=True):
document_data: str
seal: int
class DocCreateParams(DocCreateParamsReq, total=False):
metadata: Optional[str]
stamps: List[StampCreateParams]
class DocEditParams(TypedDict, total=False):
document_data: str
seal: int
metadata: Optional[str]
stamps: List[StampCreateParams]
fields = [
ModelField('document_id', int, False, False, False),
ModelField('document_data', str, True, True, True),
ModelField('seal', int, True, True, True),
ModelField('created_at', str, False, False, False),
ModelField('metadata', Optional[str], False, True, True),
ModelField('stamps', List[StampCreateParams], False, True, True),
]
req_fields = {field.name for field in fields if field.required}
edit_fields = {field.name for field in fields if field.can_edit}
create_fields = {field.name for field in fields if field.can_create}
class Document:
def __init__(self, app: web.Application, row: DataModel.Document):
self.app = app
self.data = app[datamodelsv]
self.row = row
@classmethod
async def validate_create_params(cls, params):
if extra := next((key for key in params if key not in create_fields), None):
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to document creation.")
if missing := next((key for key in req_fields if key not in params), None):
raise web.HTTPBadRequest(text=f"Document params missing required key '{missing}'")
@classmethod
async def fetch_from_id(cls, app: web.Application, document_id: int) -> Optional[Self]:
data = app[datamodelsv]
row = await data.Document.fetch(document_id)
return cls(app, row) if row is not None else None
@classmethod
async def query(
cls,
app: web.Application,
document_id: Optional[int] = None,
seal: Optional[int] = None,
created_before: Optional[str] = None,
created_after: Optional[str] = None,
metadata: Optional[str] = None,
stamp_type: Optional[str] = None,
) -> List[Self]:
data = app[datamodelsv]
Doc = data.Document
conds = []
if document_id is not None:
conds.append(Doc.document_id == int(document_id))
if seal is not None:
conds.append(Doc.seal == int(seal))
if created_before is not None:
cbefore = datetime.fromisoformat(created_before)
conds.append(Doc.created_at <= cbefore)
if created_after is not None:
cafter = datetime.fromisoformat(created_after)
conds.append(Doc.created_at >= cafter)
if metadata is not None:
conds.append(Doc.metadata == metadata)
query = data.Document.table.fetch_rows_where(*conds)
# results = await query
# query = data.Document.table.select_where(*conds)
if stamp_type is not None:
query.join('document_stamps', using=('document_id',), join_type=JOINTYPE.LEFT)
query.join(
'stamp_types',
on=(data.DocumentStamp.stamp_type == data.StampType.stamp_type_id)
)
query.where(data.StampType.stamp_type_name == stamp_type)
query.select(docid = "DISTINCT(document_id)")
query.with_no_adapter()
results = await query
ids = [result['docid'] for result in results]
if ids:
rows = await data.Document.table.fetch_rows_where(
document_id=ids
)
else:
rows = []
else:
rows = await query
return [cls(app, row) for row in sorted(rows, key=lambda row:row.created_at)]
@classmethod
async def create(cls, app: web.Application, **kwargs: Unpack[DocCreateParams]) -> Self:
data = app[datamodelsv]
document_data = kwargs['document_data']
seal = kwargs['seal']
stamp_params = kwargs.get('stamps', [])
metadata = kwargs.get('metadata')
# Build the document first
row = await data.Document.create(
document_data=document_data,
seal=seal,
metadata=metadata,
)
# Then build the stamps
stamps = []
for stampdata in stamp_params:
stampdata.setdefault('document_id', row.document_id)
stamp = await Stamp.create(app, **stampdata)
stamps.append(stamp)
self = cls(app, row)
# await self.log_create()
return self
async def get_stamps(self) -> List[Stamp]:
stamprows = await self.data.DocumentStamp.table.fetch_rows_where(document_id=self.row.document_id).order_by('stamp_id')
return [Stamp(self.app, row) for row in stamprows]
async def log_create(self, with_image=True):
args = await self.event_log_args(with_image=with_image)
args.kwargs['embed'].title = f"Document #{self.row.document_id} Created!"
try:
await event_log(**args.send_args)
except discord.HTTPException:
if with_image:
# Try again without the image in case that was the issue
await self.log_create(with_image=False)
async def log_edit(self, with_image=True):
args = await self.event_log_args(with_image=with_image)
args.kwargs['embed'].title = f"Document #{self.row.document_id} Updated!"
try:
await event_log(**args.send_args)
except discord.HTTPException:
if with_image:
# Try again without the image in case that was the issue
await self.log_create(with_image=False)
async def event_log_args(self, with_image=True) -> MessageArgs:
desc = '\n'.join(await self.tabulate())
embed = discord.Embed(description=desc, timestamp=self.row.created_at)
embed.set_footer(text='Created At')
args: dict = {'embed': embed}
if with_image:
try:
imagedata = self.row.to_bytes()
imagedata.seek(0)
embed.set_image(url='attachment://document.png')
args['files'] = [discord.File(imagedata, "document.png")]
except binascii.Error:
# Could not decode base64
embed.add_field(name='Image', value="Could not decode document data!")
else:
embed.add_field(name='Image', value="Failed to send image!")
return MessageArgs(**args)
async def tabulate(self):
"""
Present the Document as a discord-readable table.
"""
stamps = await self.get_stamps()
typnames = []
if stamps:
typs = {stamp.row.stamp_type for stamp in stamps}
for typ in typs:
# Stamp types should be cached so this isn't expensive
typrow = await self.data.StampType.fetch(typ)
typnames.append(typrow.stamp_type_name)
table = {
'document_id': f"`{self.row.document_id}`",
'seal': str(self.row.seal),
'metadata': f"`{self.row.metadata}`" if self.row.metadata else "No metadata",
'stamps': ', '.join(f"`{name}`" for name in typnames) if typnames else "No stamps",
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
}
return tabulate(*table.items())
async def prepare(self) -> DocPayload:
stamps = await self.get_stamps()
results: DocPayload = {
'document_id': self.row.document_id,
'document_data': self.row.document_data,
'seal': self.row.seal,
'created_at': self.row.created_at.isoformat(),
'metadata': self.row.metadata,
'stamps': [await stamp.prepare() for stamp in stamps]
}
return results
async def edit(self, **kwargs: Unpack[DocEditParams]):
data = self.data
row = self.row
# Update the row data
# If stamps are given, delete the existing ones
# Then write in the new ones.
update_args = {}
for key in {'document_data', 'seal', 'metadata'}:
if key in kwargs:
update_args[key] = kwargs[key]
if update_args:
await self.row.update(**update_args)
# TODO: Should really be in a transaction
# Actually each handler should be in a transaction
if new_stamps := kwargs.get('stamps', []):
await self.data.DocumentStamp.table.delete_where(document_id=self.row.document_id)
for stampdata in new_stamps:
stampdata.setdefault('document_id', row.document_id)
await Stamp.create(self.app, **stampdata)
await self.log_edit()
async def delete(self) -> DocPayload:
payload = await self.prepare()
await self.row.delete()
return payload
@routes.view('/documents')
@routes.view('/documents/')
class DocumentsView(web.View):
async def get(self):
request = self.request
filter_params = {}
keys = [
'document_id',
'seal',
'created_before',
'created_after',
'metadata',
'stamp_type',
]
for key in keys:
if key in request.query:
filter_params[key] = request.query[key]
elif key in request:
filter_params[key] = request[key]
documents = await Document.query(request.app, **filter_params)
payload = [await doc.prepare() for doc in documents]
return web.json_response(payload)
async def post(self):
request = self.request
params = await request.json()
for key in create_fields:
if key in request:
params.setdefault(key, request[key])
await Document.validate_create_params(params)
document = await Document.create(self.request.app, **params)
payload = await document.prepare()
return web.json_response(payload)
@routes.view('/documents/{document_id}')
@routes.view('/documents/{document_id}/')
class DocumentView(web.View):
async def resolve_document(self):
request = self.request
document_id = request.match_info['document_id']
document = await Document.fetch_from_id(request.app, int(document_id))
if document is None:
raise web.HTTPNotFound(text="No document exists with the given ID.")
return document
async def get(self):
doc = await self.resolve_document()
payload = await doc.prepare()
return web.json_response(payload)
async def patch(self):
doc = await self.resolve_document()
params = await self.request.json()
edit_data = {}
for key, value in params.items():
if key not in edit_fields:
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of Document!")
edit_data[key] = value
for key in edit_fields:
if key in self.request:
edit_data.setdefault(key, self.request[key])
await doc.edit(**edit_data)
payload = await doc.prepare()
return web.json_response(payload)
async def delete(self):
doc = await self.resolve_document()
payload = await doc.delete()
return web.json_response(payload)
# We have one prefix route, /documents/{document_id}/stamps
@routes.route('*', "/documents/{document_id}{tail:/stamps}")
@routes.route('*', "/documents/{document_id}{tail:/stamps/.*}")
async def document_stamps_route(request: web.Request):
document_id = int(request.match_info['document_id'])
document = await Document.fetch_from_id(request.app, document_id)
if document is None:
raise web.HTTPNotFound(text="No document exists with the given ID.")
new_path = request.match_info['tail']
logger.info(f"Redirecting {request=} to {new_path=} and setting {document_id=}")
new_request = request.clone(rel_url=new_path)
new_request['document_id'] = document_id
match_info = await request.app.router.resolve(new_request)
match_info.current_app = request.app
new_request._match_info = match_info
if match_info.handler:
return await match_info.handler(new_request)
else:
raise web.HTTPNotFound()

505
src/routes/events.py Normal file
View File

@@ -0,0 +1,505 @@
import binascii
import logging
from datetime import datetime
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List
from aiohttp import web
import discord
from data import Condition, condition
from data.conditions import NULL
from data.queries import JOINTYPE
from datamodels import DataModel, EventType
from modules.profiles.data import ProfileData
from utils.lib import MessageArgs, tabulate
from .lib import ModelField, datamodelsv, dbvar, event_log, profiledatav
from .specimens import Specimen, SpecimenPayload
routes = web.RouteTableDef()
logger = logging.getLogger(__name__)
class Event:
def __init__(self, app: web.Application, row: DataModel.EventDetails):
self.app = app
self.data = app[datamodelsv]
self.row = row
@classmethod
async def fetch_from_id(cls, app: web.Application, event_id: int):
data = app[datamodelsv]
row = await data.EventDetails.fetch(int(event_id))
return cls(app, row) if row is not None else None
@classmethod
async def query(
cls,
app: web.Application,
event_id: Optional[str] = None,
document_id: Optional[str] = None,
document_seal: Optional[str] = None,
user_id: Optional[str] = None,
user_name: Optional[str] = None,
occurred_before: Optional[str] = None,
occurred_after: Optional[str] = None,
created_before: Optional[str] = None,
created_after: Optional[str] = None,
event_type: Optional[str] = None,
) -> List[Self]:
data = app[datamodelsv]
EventD = data.EventDetails
conds = []
if event_id is not None:
conds.append(EventD.event_id == int(event_id))
if document_id is not None:
conds.append(EventD.document_id == int(document_id))
if document_seal is not None:
conds.append(EventD.document_seal == int(document_seal))
if user_id is not None:
conds.append(EventD.user_id == int(user_id))
if user_name is not None:
conds.append(EventD.user_name == user_name)
if created_before is not None:
cbefore = datetime.fromisoformat(created_before)
conds.append(EventD.created_at <= cbefore)
if created_after is not None:
cafter = datetime.fromisoformat(created_after)
conds.append(EventD.created_at >= cafter)
if occurred_before is not None:
before = datetime.fromisoformat(occurred_before)
conds.append(EventD.occurred_at <= before)
if occurred_after is not None:
after = datetime.fromisoformat(occurred_after)
conds.append(EventD.occurred_at >= after)
if event_type is not None:
ekey = (event_type.lower().strip(),)
if ekey not in [e.value for e in EventType]:
raise web.HTTPBadRequest(text=f"Unknown event type '{event_type}'")
conds.append(EventD.event_type == EventType(ekey))
rows = await EventD.fetch_where(*conds).order_by(EventD.occurred_at)
return [cls(app, row) for row in rows]
@classmethod
async def validate_create_params(cls, params):
if 'event_type' not in params:
raise web.HTTPBadRequest(text="Event creation missing required field 'event_type'.")
ekey = (params['event_type'].lower().strip(),)
if ekey not in [e.value for e in EventType]:
raise web.HTTPBadRequest(text=f"Unknown event type '{params['event_type']}'")
event_type = EventType(ekey)
req_fields = {
'user_name', 'occurred_at', 'event_type',
}
other_fields = {
'document_id', 'document',
'user_id', 'user',
}
if 'user_id' not in params and 'user' not in params:
raise web.HTTPBadRequest(text="One of 'user_id' or 'user' must be supplied to create Event.")
match event_type:
case EventType.PLAIN:
req_fields.add('message')
case EventType.SUBSCRIBER:
req_fields.add('tier')
req_fields.add('subscribed_length')
other_fields.add('message')
case EventType.CHEER:
req_fields.add('amount')
other_fields.add('cheer_type')
other_fields.add('message')
case EventType.RAID:
req_fields.add('viewer_count')
create_fields = req_fields.union(other_fields)
if extra := next((key for key in params if key not in create_fields), None):
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to {event_type} event creation.")
if missing := next((key for key in req_fields if key not in params), None):
raise web.HTTPBadRequest(text=f"{event_type} Event params missing required key '{missing}'")
@classmethod
async def create(cls, app: web.Application, **kwargs):
data = app[datamodelsv]
# EventD = data.EventDetails
ekey = (kwargs['event_type'].lower().strip(),)
if ekey not in [e.value for e in EventType]:
raise web.HTTPBadRequest(text=f"Unknown event type '{kwargs['event_type']}'")
event_type = EventType(ekey)
params = {}
typparams = {}
match event_type:
case EventType.PLAIN:
typtab = data.plain_events
typparams['message'] = kwargs['message']
case EventType.CHEER:
typtab = data.cheer_events
typparams['amount'] = kwargs['amount']
typparams['cheer_type'] = kwargs.get('cheer_type')
typparams['message'] = kwargs.get('message')
case EventType.RAID:
typtab = data.raid_events
typparams['visitor_count'] = kwargs.get('viewer_count')
case EventType.SUBSCRIBER:
typtab = data.subscriber_events
typparams['tier'] = kwargs['tier']
typparams['subscribed_length'] = kwargs['subscribed_length']
typparams['message'] = kwargs.get('message')
case _:
raise ValueError("Invalid EventType")
# TODO: This really really should be a transaction
# Create Document if required
if 'document' in kwargs:
from .documents import Document
doc_args = kwargs['document']
await Document.validate_create_params(doc_args)
doc = await Document.create(app, **doc_args)
document_id = doc.row.document_id
params['document_id'] = document_id
elif 'document_id' in kwargs:
document_id = kwargs['document_id']
params['document_id'] = document_id
# Create User if required
if 'user' in kwargs:
from .users import User
user_args = kwargs['user']
await User.validate_create_params(user_args)
user = await User.create(app, **user_args)
user_id = user.row.user_id
if 'user_id' in kwargs and not kwargs['user_id'] == user_id:
raise web.HTTPBadRequest(text="Provided 'user_id' does not match provided 'user'.")
else:
user_id = kwargs['user_id']
params['user_id'] = user_id
# Create Event row
params['event_type'] = event_type
params['user_name'] = kwargs['user_name']
params['occurred_at'] = datetime.fromisoformat(kwargs['occurred_at'])
eventrow = await data.Events.create(**params)
typparams['event_id'] = eventrow.event_id
# Create Event type row
typrow = await typtab.insert(**typparams)
details = await data.EventDetails.fetch(eventrow.event_id)
assert details is not None
self = cls(app, details)
await self.log_create()
return self
async def log_create(self, with_image=True):
args = await self.event_log_args(with_image=with_image)
args.kwargs['embed'].title = f"Event #{self.row.event_id} Created!"
try:
await event_log(**args.send_args)
except discord.HTTPException:
if with_image:
# Try again without the image in case that was the issue
await self.log_create(with_image=False)
async def log_edit(self, with_image=True):
args = await self.event_log_args(with_image=with_image)
args.kwargs['embed'].title = f"Event #{self.row.event_id} Updated!"
try:
await event_log(**args.send_args)
except discord.HTTPException:
if with_image:
# Try again without the image in case that was the issue
await self.log_create(with_image=False)
async def event_log_args(self, with_image=True) -> MessageArgs:
desc = '\n'.join(await self.tabulate())
embed = discord.Embed(description=desc, timestamp=self.row.created_at)
embed.set_footer(text='Created At')
args: dict = {'embed': embed}
doc = await self.get_document()
if doc is not None:
embed.add_field(
name="Document",
value='\n'.join(await doc.tabulate()),
inline=False
)
if with_image:
try:
imagedata = doc.row.to_bytes()
imagedata.seek(0)
embed.set_image(url='attachment://document.png')
args['files'] = [discord.File(imagedata, "document.png")]
except binascii.Error:
# Could not decode base64
embed.add_field(name='Image', value="Could not decode document data!")
else:
embed.add_field(name='Image', value="Failed to send image!")
return MessageArgs(**args)
async def tabulate(self):
"""
Present the Event as a discord-readable table.
"""
user = await self.get_user()
assert user is not None
table = {
'event_id': f"`{self.row.event_id}`",
'event_type': f"`{self.row.event_type}`",
'user': f"`{self.row.user_id}` (`{self.row.user_name}`)",
'document': f"`{self.row.document_id}`",
'occurred_at': discord.utils.format_dt(self.row.occurred_at, 'F'),
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
}
info = self.row.event_type.info()
for col, param in zip(info.detailcolumns, info.params):
value = getattr(self.row, col)
table[param] = f"`{value}`"
return tabulate(*table.items())
async def edit(self, **kwargs):
data = self.data
# EventD = data.EventDetails
if 'event_type' in kwargs:
raise web.HTTPBadRequest(text="You cannot change the type of an event after creation.")
typparams = {}
match self.row.event_type:
case EventType.PLAIN:
typtab = data.plain_events
if 'message' in kwargs:
typparams['message'] = kwargs['message']
case EventType.CHEER:
typtab = data.cheer_events
for key in ('amount', 'cheer_type', 'message'):
if key in kwargs:
typparams[key] = kwargs[key]
case EventType.RAID:
typtab = data.raid_events
if 'viewer_count' in kwargs:
typparams['visitor_count'] = 'viewer_count'
case EventType.SUBSCRIBER:
typtab = data.subscriber_events
for key in ('tier', 'subscribed_length', 'message'):
if key in kwargs:
typparams[key] = kwargs[key]
if typparams:
await typtab.update_where(event_id=self.row.event_id).set(**typparams)
await self.log_edit()
await self.row.refresh()
async def delete(self):
payload = await self.prepare()
if self.row.document_id:
await self.data.Document.table.delete_where(document_id=self.row.document_id)
await self.data.Events.table.delete_where(event_id=self.row.event_id)
await self.row.refresh()
return payload
async def get_user(self):
from .users import User
return await User.fetch_from_id(self.app, self.row.user_id)
async def get_document(self):
from .documents import Document
if self.row.document_id:
return await Document.fetch_from_id(self.app, self.row.document_id)
async def prepare(self):
row = await self.row.refresh()
assert row is not None
data = self.data
user = await self.get_user()
assert user is not None
document = await self.get_document()
payload = {
'event_id': self.row.event_id,
'document_id': self.row.document_id,
'document': await document.prepare() if document else None,
'user_id': self.row.user_id,
'user': await user.prepare(),
'user_name': self.row.user_name,
'occurred_at': self.row.occurred_at.isoformat(),
'created_at': self.row.created_at.isoformat(),
'event_type': self.row.event_type.value[0],
}
match row.event_type:
case EventType.PLAIN:
payload['message'] = row.plain_message
case EventType.SUBSCRIBER:
payload['tier'] = row.subscriber_tier
payload['subscribed_length'] = row.subscriber_length
payload['message'] = row.subscriber_message
case EventType.CHEER:
payload['amount'] = row.cheer_amount
payload['cheer_type'] = row.cheer_type
payload['message'] = row.cheer_message
case EventType.RAID:
payload['viewer_count'] = row.raid_visitor_count
return payload
@routes.view('/events')
@routes.view('/events/', name='events')
class EventsView(web.View):
async def post(self):
request = self.request
params = await request.json()
if 'user_id' in request:
params.setdefault('user_id', request['user_id'])
await Event.validate_create_params(params)
logger.info(f"Creating a new event with args: {params=}")
event = await Event.create(self.request.app, **params)
logger.debug(f"Created event: {event!r}")
payload = await event.prepare()
return web.json_response(payload)
async def get(self):
request = self.request
filter_params = {}
keys = [
'event_id', 'document_id', 'document_seal',
'user_id', 'user_name', 'occurred_before', 'occurred_after',
'created_before', 'created_after', 'event_type',
]
for key in keys:
value = request.query.get(key, request.get(key, None))
filter_params[key] = value
logger.info(f"Querying events with params: {filter_params=}")
events = await Event.query(request.app, **filter_params)
payload = [await event.prepare() for event in events]
return web.json_response(payload)
@routes.view('/events/{event_id}')
@routes.view('/events/{event_id}/', name='event')
class EventView(web.View):
async def resolve_event(self):
request = self.request
event_id = request.match_info['event_id']
event = await Event.fetch_from_id(request.app, int(event_id))
if event is None:
raise web.HTTPNotFound(text="No event exists with the given ID.")
return event
async def get(self):
event = await self.resolve_event()
logger.info(f"Received GET for event {event=}")
payload = await event.prepare()
return web.json_response(payload)
async def patch(self):
event = await self.resolve_event()
params = await self.request.json()
edit_data = {}
edit_fields = {'message', 'amount', 'cheer_type', 'viewer_count', 'tier', 'subscriber_length', 'message'}
for key, value in params.items():
if key not in edit_fields:
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of User!")
edit_data[key] = value
for key in edit_fields:
if key in self.request:
edit_data.setdefault(key, self.request[key])
logger.info(f"Received PATCH for event {event} with params: {params}")
await event.edit(**edit_data)
payload = await event.prepare()
return web.json_response(payload)
async def delete(self):
event = await self.resolve_event()
logger.info(f"Received DELETE for event {event}")
payload = await event.delete()
return web.json_response(payload)
@routes.route('*', "/events/{event_id}/user")
@routes.route('*', "/events/{event_id}/user{tail:/.*}")
async def event_user_route(request: web.Request):
event_id = int(request.match_info['event_id'])
event = await Event.fetch_from_id(request.app, event_id)
if event is None:
raise web.HTTPNotFound(text="No event exists with the given ID.")
tail = request.match_info.get('tail', '')
new_path = "/users/{user_id}".format(user_id=event.row.user_id) + tail
logger.info(f"Redirecting {request=} to {new_path}")
new_request = request.clone(rel_url=new_path)
new_request['user_id'] = event.row.user_id
match_info = await request.app.router.resolve(new_request)
new_request._match_info = match_info
match_info.current_app = request.app
if match_info.handler:
return await match_info.handler(new_request)
else:
logger.info(f"Could not find handler matching {new_request}")
raise web.HTTPNotFound()
@routes.route('*', "/events/{event_id}/document")
@routes.route('*', "/events/{event_id}/document{tail:/.*}")
async def event_document_route(request: web.Request):
event_id = int(request.match_info['event_id'])
event = await Event.fetch_from_id(request.app, event_id)
if event is None:
raise web.HTTPNotFound(text="No event exists with the given ID.")
tail = request.match_info.get('tail', '')
document = await event.get_document()
if document is None:
if request.method == 'POST' and not tail:
new_path = '/documents'
logger.info(f"Redirecting {request=} to POST /documents")
new_request = request.clone(rel_url=new_path)
new_request['event_id'] = event_id
match_info = await request.app.router.resolve(new_request)
new_request._match_info = match_info
match_info.current_app = request.app
return await match_info.handler(new_request)
else:
raise web.HTTPNotFound(text="This event has no document.")
else:
document_id = document.row.document_id
# Redirect to POST /documents/{document_id}/...
new_path = f"/documents/{document_id}".format(document_id=document_id) + tail
logger.info(f"Redirecting {request=} to {new_path}")
new_request = request.clone(rel_url=new_path)
new_request['event_id'] = event_id
new_request['document_id'] = document_id
match_info = await request.app.router.resolve(new_request)
new_request._match_info = match_info
match_info.current_app = request.app
if match_info.handler:
return await match_info.handler(new_request)
else:
logger.info(f"Could not find handler matching {new_request}")
raise web.HTTPNotFound()

29
src/routes/lib.py Normal file
View File

@@ -0,0 +1,29 @@
from typing import NamedTuple, Any, Optional, Self, Unpack, List, TypedDict
from aiohttp import web, ClientSession
from discord import Webhook
from data.database import Database
from datamodels import DataModel
from modules.profiles.data import ProfileData
from meta import conf
dbvar = web.AppKey("database", Database)
datamodelsv = web.AppKey("datamodels", DataModel)
profiledatav = web.AppKey("profiledata", ProfileData)
class ModelField(NamedTuple):
name: str
typ: Any
required: bool
can_create: bool
can_edit: bool
async def event_log(*args, **kwargs):
# Post the given message to the configured event log, if set
event_log_url = conf.api.get('EVENTLOG')
if event_log_url:
async with ClientSession() as session:
webhook = Webhook.from_url(event_log_url, session=session)
await webhook.send(**kwargs)

304
src/routes/specimens.py Normal file
View File

@@ -0,0 +1,304 @@
"""
- `/specimens` with `GET` and `POST`
- `/specimens/{specimen_id}` with `PATCH` and `DELETE`
- `/specimens/{specimen_id}/owner` which is passed to `/users/{user_id}`
"""
import logging
from datetime import datetime
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List, TYPE_CHECKING
from aiohttp import web
from data import Condition, condition
from data.conditions import NULL
from data.queries import JOINTYPE
from datamodels import DataModel
from .lib import ModelField, datamodelsv, dbvar
if TYPE_CHECKING:
from .users import UserCreateParams, UserPayload, User
routes = web.RouteTableDef()
logger = logging.getLogger(__name__)
class SpecimenPayload(TypedDict):
specimen_id: int
owner_id: int
owner: 'UserPayload'
born_at: str
forgotten_at: Optional[str]
class SpecimenCreateParamsReq(TypedDict, total=True):
owner_id: int
class SpecimenCreateParams(SpecimenCreateParamsReq, total=False):
owner: 'UserCreateParams'
born_at: str
forgotten_at: str
class SpecimenEditParams(TypedDict, total=False):
owner_id: int
forgotten_at: Optional[str]
fields = [
ModelField('specimen_id', int, False, False, False),
ModelField('owner_id', int, False, True, True),
ModelField('owner', 'UserPayload', False, True, False),
ModelField('born_at', str, False, True, False),
ModelField('forgotten_at', str, False, True, True),
]
req_fields = {field.name for field in fields if field.required}
edit_fields = {field.name for field in fields if field.can_edit}
create_fields = {field.name for field in fields if field.can_create}
class Specimen:
def __init__(self, app: web.Application, row: DataModel.Specimen):
self.app = app
self.data = app[datamodelsv]
self.row = row
@classmethod
async def validate_create_params(cls, params):
if 'owner_id' not in params and 'owner' not in params:
raise web.HTTPBadRequest(text="One of 'owner' or 'owner_id' must be supplied to create Specimen.")
if extra := next((key for key in params if key not in create_fields), None):
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to specimen creation.")
if missing := next((key for key in req_fields if key not in params), None):
raise web.HTTPBadRequest(text=f"Specimen params missing required key '{missing}'")
@classmethod
async def fetch_from_id(cls, app: web.Application, spec_id: int) -> Optional[Self]:
data = app[datamodelsv]
row = await data.Specimen.fetch(int(spec_id))
return cls(app, row) if row is not None else None
@classmethod
async def query(
cls,
app: web.Application,
specimen_id: Optional[str] = None,
owner_id: Optional[str] = None,
born_after: Optional[str] = None,
born_before: Optional[str] = None,
forgotten: Optional[str] = None,
forgotten_after: Optional[str] = None,
forgotten_before: Optional[str] = None,
) -> List[Self]:
data = app[datamodelsv]
Spec = data.Specimen
conds = []
if specimen_id is not None:
conds.append(Spec.specimen_id == int(specimen_id))
if owner_id is not None:
conds.append(Spec.owner_id == int(owner_id))
if born_after is not None:
bafter = datetime.fromisoformat(born_after)
conds.append(Spec.born_at >= bafter)
if born_before is not None:
bbefore = datetime.fromisoformat(born_before)
conds.append(Spec.born_at <= bbefore)
if forgotten_after is not None:
fafter = datetime.fromisoformat(forgotten_after)
conds.append(Spec.forgotten_at >= fafter)
if forgotten_before is not None:
fbefore = datetime.fromisoformat(forgotten_before)
conds.append(Spec.forgotten_at <= fbefore)
if forgotten is not None:
if forgotten.lower() in ('1', 'true'):
conds.append(Spec.forgotten_at != NULL)
elif forgotten.lower() in ('0', 'false'):
conds.append(Spec.forgotten_at == NULL)
rows = await Spec.fetch_where(*conds).order_by(Spec.born_at)
return [cls(app, row) for row in rows]
@classmethod
async def create(cls, app: web.Application, **kwargs: Unpack[SpecimenCreateParams]) -> Self:
"""
Create a new specimen from the given data.
This will create the provided 'owner' if required.
"""
from .users import User
create_args = {}
if 'owner' in kwargs:
# Create owner and set owner_id
owner_args = kwargs['owner']
await User.validate_create_params(owner_args)
owner = await User.create(app, **owner_args)
owner_id = owner.row.user_id
if 'owner_id' in kwargs and not kwargs['owner_id'] == owner_id:
raise web.HTTPBadRequest(text="Provided `owner_id` does not match provided `owner`.")
else:
owner_id = int(kwargs['owner_id'])
create_args['owner_id'] = owner_id
if 'born_at' in kwargs:
create_args['born_at'] = datetime.fromisoformat(kwargs['born_at'])
if 'forgotten_at' in kwargs:
create_args['forgotten_at'] = datetime.fromisoformat(kwargs['forgotten_at'])
data = app[datamodelsv]
logger.info(f"Creating Specimen with {create_args=}")
row = await data.Specimen.create(**create_args)
return cls(app, row)
async def edit(self, **kwargs: Unpack[SpecimenEditParams]):
row = self.row
edit_args = {}
if 'owner_id' in kwargs:
edit_args['owner_id'] = kwargs['owner_id']
# TODO: We should probably check that the specified owner exists
if 'forgotten_at' in kwargs:
forg = kwargs['forgotten_at']
if forg is None:
# Allows unsetting the forgotten date
# This may error if the user already had a live specimen
edit_args['forgotten_at'] = None
else:
edit_args['forgotten_at'] = datetime.fromisoformat(forg)
if edit_args:
logger.info(f"Updating specimen {row=} with {kwargs}")
await row.update(**edit_args)
async def delete(self) -> SpecimenPayload:
payload = await self.prepare()
await self.row.delete()
return payload
async def get_owner(self):
from .users import User
return await User.fetch_from_id(self.app, self.row.owner_id)
async def prepare(self) -> SpecimenPayload:
owner = await self.get_owner()
if owner is None:
raise ValueError("Specimen Owner does not exist! This should never happen!")
results: SpecimenPayload = {
'specimen_id': self.row.specimen_id,
'owner_id': self.row.owner_id,
'owner': await owner.prepare(),
'born_at': self.row.born_at.isoformat(),
'forgotten_at': self.row.forgotten_at.isoformat() if self.row.forgotten_at else None
}
return results
@routes.view('/specimens')
@routes.view('/specimens/', name='specimens')
class SpecimensView(web.View):
async def post(self):
request = self.request
params = await request.json()
for key in create_fields:
if key in request:
params.setdefault(key, request[key])
await Specimen.validate_create_params(params)
logger.info(f"Creating a new Specimen with args: {params=}")
spec = await Specimen.create(self.request.app, **params)
logger.debug(f"Created specimen: {spec!r}")
payload = await spec.prepare()
return web.json_response(payload)
async def get(self):
request = self.request
filter_params = {}
keys = [
'specimen_id', 'owner_id', 'born_after', 'born_before',
'forgotten', 'forgotten_after', 'forgotten_before'
]
for key in keys:
value = request.query.get(key, request.get(key, None))
filter_params[key] = value
logger.info(f"Querying specimens with params: {filter_params=}")
specs = await Specimen.query(request.app, **filter_params)
payload = [await spec.prepare() for spec in specs]
return web.json_response(payload)
@routes.view('/specimens/{specimen_id}')
@routes.view('/specimens/{specimen_id}/', name='specimen')
class SpecimenView(web.View):
async def resolve_specimen(self):
request = self.request
spec_id = request.match_info['specimen_id']
spec = await Specimen.fetch_from_id(request.app, int(spec_id))
if spec is None:
raise web.HTTPNotFound(text="No specimen exists with the given ID.")
return spec
async def get(self):
spec = await self.resolve_specimen()
logger.info(f"Received GET for specimen {spec=}")
payload = await spec.prepare()
return web.json_response(payload)
async def patch(self):
spec = await self.resolve_specimen()
params = await self.request.json()
edit_data = {}
for key, value in params.items():
if key not in edit_fields:
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of Specimen!")
edit_data[key] = value
for key in edit_fields:
if key in self.request:
edit_data.setdefault(key, self.request[key])
logger.info(f"Received PATCH for specimen {spec} with params: {params}")
await spec.edit(**edit_data)
payload = await spec.prepare()
return web.json_response(payload)
async def delete(self):
spec = await self.resolve_specimen()
logger.info(f"Received DELETE for specimen {spec}")
payload = await spec.delete()
return web.json_response(payload)
@routes.route('*', "/specimens/{specimen_id}/owner")
@routes.route('*', "/specimens/{specimen_id}/owner{tail:/.*}")
async def specimen_owner_route(request: web.Request):
spec_id = int(request.match_info['specimen_id'])
spec = await Specimen.fetch_from_id(request.app, spec_id)
if spec is None:
raise web.HTTPNotFound(text="No specimen exists with the given ID.")
tail = request.match_info.get('tail', '')
new_path = "/users/{user_id}".format(user_id=spec.row.owner_id) + tail
logger.info(f"Redirecting {request=} to {new_path}")
new_request = request.clone(rel_url=new_path)
new_request['user_id'] = spec.row.owner_id
match_info = await request.app.router.resolve(new_request)
new_request._match_info = match_info
match_info.current_app = request.app
if match_info.handler:
return await match_info.handler(new_request)
else:
logger.info(f"Could not find handler matching {new_request}")
raise web.HTTPNotFound()

262
src/routes/stamps.py Normal file
View File

@@ -0,0 +1,262 @@
from typing import Any, NamedTuple, Optional, TypedDict, Unpack, reveal_type
from aiohttp import web
from data.database import Database
from datamodels import DataModel
from .lib import datamodelsv, ModelField
routes = web.RouteTableDef()
class StampPayload(TypedDict):
stamp_id: int
document_id: int
stamp_type: str
pos_x: int
pos_y: int
rotation: float
class StampCreateParamsReq(TypedDict, total=True):
document_id: int
stamp_type: str
pos_x: int
pos_y: int
rotation: float
class StampCreateParams(StampCreateParamsReq, total=False):
pass
class StampEditParams(TypedDict, total=False):
document_id: int
stamp_type: str
pos_x: int
pos_y: int
rotation: float
fields = [
ModelField('stamp_id', int, False, False, False),
ModelField('document_id', int, True, True, True),
ModelField('stamp_type', int, True, True, True),
ModelField('pos_x', int, True, True, True),
ModelField('pos_y', int, True, True, True),
ModelField('rotation', float, True, True, True),
]
req_fields = {field.name for field in fields if field.required}
edit_fields = {field.name for field in fields if field.can_edit}
create_fields = {field.name for field in fields if field.can_create}
class Stamp:
def __init__(self, app: web.Application, row: DataModel.DocumentStamp):
self.app = app
self.data = app[datamodelsv]
self.row = row
@classmethod
async def fetch_from_id(cls, app: web.Application, stamp_id: int) -> Optional['Stamp']:
stamp = await app[datamodelsv].DocumentStamp.fetch(stamp_id)
if stamp is None:
return None
return cls(app, stamp)
@classmethod
async def query(
cls,
app: web.Application,
stamp_id: Optional[int] = None,
document_id: Optional[int] = None,
stamp_type: Optional[str] = None,
):
data = app[datamodelsv]
query_args = {}
if stamp_id is not None:
query_args['stamp_id'] = int(stamp_id)
if document_id is not None:
query_args['document_id'] = int(document_id)
if stamp_type is not None:
typerows = await data.StampType.table.fetch_rows_where(stamp_type_name=stamp_type)
typeids = [row.stamp_type_id for row in typerows]
if not typeids:
return []
query_args['stamp_type'] = typeids
results = await data.DocumentStamp.table.fetch_rows_where(**query_args)
return [cls(app, row) for row in sorted(results, key=lambda row:row.stamp_id)]
@classmethod
async def create(
cls,
app: web.Application,
**kwargs: Unpack[StampCreateParams]
):
data = app[datamodelsv]
stamp_type = kwargs['stamp_type']
# Get the stamp_type
rows = await data.StampType.table.fetch_rows_where(stamp_type_name=stamp_type)
if not rows:
# Create the stamp type
row = await data.StampType.create(stamp_type_name=stamp_type)
else:
row = rows[0]
stamprow = await data.DocumentStamp.create(
document_id=kwargs['document_id'],
stamp_type=row.stamp_type_id,
position_x=int(kwargs['pos_x']),
position_y=int(kwargs['pos_y']),
rotation=float(kwargs['rotation'])
)
return cls(app, stamprow)
async def prepare(self) -> StampPayload:
typerow = await self.data.StampType.fetch(self.row.stamp_type)
assert typerow is not None
results: StampPayload = {
'stamp_id': self.row.stamp_id,
'document_id': self.row.document_id,
'stamp_type': typerow.stamp_type_name,
'pos_x': self.row.position_x,
'pos_y': self.row.position_y,
'rotation': self.row.rotation,
}
return results
async def edit(
self,
**kwargs: Unpack[StampEditParams]
):
data = self.data
row = self.row
edit_args = {}
if stamp_type := kwargs.get('stamp_type'):
# Get the stamp_type
rows = await data.StampType.table.fetch_rows_where(stamp_type_name=stamp_type)
if not rows:
# Create the stamp type
row = await data.StampType.create(stamp_type_name=stamp_type)
else:
row = rows[0]
edit_args['stamp_type'] = row.stamp_type_id
simple_keys = {
'document_id': 'document_id',
'pos_x': 'position_x',
'pos_y': 'position_y',
'rotation': 'rotation'
}
for editkey, datakey in simple_keys.items():
if editkey in kwargs:
edit_args[datakey] = kwargs[editkey]
await self.row.update(
**edit_args
)
async def delete(self) -> StampPayload:
payload = await self.prepare()
await self.row.delete()
return payload
@routes.view('/stamps')
@routes.view('/stamps/', name='stamps')
class StampsView(web.View):
async def get(self):
request = self.request
# Decode request parameters to filter args
filter_params = {}
keys = ['stamp_id', 'document_id', 'stamp_type']
for key in keys:
if key in request.query:
filter_params[key] = request.query[key]
elif key in request:
filter_params[key] = request[key]
stamps = await Stamp.query(request.app, **filter_params)
payload = [await stamp.prepare() for stamp in stamps]
return web.json_response(payload)
async def create_one(self, params: StampCreateParams):
if extra := next((key for key in params if key not in create_fields), None):
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to stamp creation.")
if missing := next((key for key in req_fields if key not in params), None):
raise web.HTTPBadRequest(text=f"Stamp params missing required key '{missing}'.")
# This still doesn't guarantee that the values are of the correct type, but good enough.
stamp = await Stamp.create(self.request.app, **params)
return stamp
async def post(self):
request = self.request
params = await request.json()
for key in create_fields:
if key in request:
params.setdefault(key, request[key])
stamp = await self.create_one(params)
payload = await stamp.prepare()
return web.json_response(payload)
async def put(self):
request = self.request
from_request = {key: request[key] for key in create_fields if key in request}
argslist = await request.json()
payloads = []
for args in argslist:
stamp = await self.create_one(from_request | args)
payload = await stamp.prepare()
payloads.append(payload)
return web.json_response(payloads)
@routes.view('/stamps/{stamp_id}')
@routes.view('/stamps/{stamp_id}/', name='stamp')
class StampView(web.View):
async def resolve_stamp(self):
request = self.request
stamp_id = request.match_info['stamp_id']
stamp = await Stamp.fetch_from_id(request.app, int(stamp_id))
if stamp is None:
raise web.HTTPNotFound(text="No stamp exists with the given ID.")
return stamp
async def get(self):
stamp = await self.resolve_stamp()
payload = await stamp.prepare()
return web.json_response(payload)
async def patch(self):
stamp = await self.resolve_stamp()
params = await self.request.json()
edit_data = {}
for key, value in params.items():
if key not in edit_fields:
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of Stamp!")
edit_data[key] = value
for key in edit_fields:
if key in self.request:
edit_data.setdefault(key, self.request[key])
await stamp.edit(**edit_data)
payload = await stamp.prepare()
return web.json_response(payload)
async def delete(self):
stamp = await self.resolve_stamp()
payload = await stamp.delete()
return web.json_response(payload)

223
src/routes/transactions.py Normal file
View File

@@ -0,0 +1,223 @@
import logging
from datetime import datetime
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List, TYPE_CHECKING
from aiohttp import web
from data import Condition, condition
from data.conditions import NULL
from data.queries import JOINTYPE
from datamodels import DataModel
from .lib import ModelField, datamodelsv, dbvar
if TYPE_CHECKING:
from .users import UserCreateParams, UserPayload, User
routes = web.RouteTableDef()
logger = logging.getLogger(__name__)
class TransactionPayload(TypedDict):
transaction_id: int
user_id: int
user: 'UserPayload'
amount: int
description: str
reference: Optional[str]
created_at: str
class TransactionCreateParamsReq(TypedDict, total=True):
user_id: int
amount: int
description: str
class TransactionCreateParams(TransactionCreateParamsReq, total=False):
reference: str
fields = [
ModelField('transaction_id', int, False, False, False),
ModelField('user_id', int, True, True, False),
ModelField('amount', int, True, True, False),
ModelField('description', str, True, True, False),
ModelField('reference', str, False, True, False),
ModelField('created_at', str, False, False, False),
]
req_fields = {field.name for field in fields if field.required}
edit_fields = {field.name for field in fields if field.can_edit}
create_fields = {field.name for field in fields if field.can_create}
class Transaction:
def __init__(self, app: web.Application, row: DataModel.Transaction):
self.app = app
self.data = app[datamodelsv]
self.row = row
@classmethod
async def validate_create_params(cls, params):
if extra := next((key for key in params if key not in create_fields), None):
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to transaction creation.")
if missing := next((key for key in req_fields if key not in params), None):
raise web.HTTPBadRequest(text=f"Transaction params missing required key '{missing}'")
@classmethod
async def fetch_from_id(cls, app: web.Application, tid: int) -> Optional[Self]:
data = app[datamodelsv]
row = await data.Transaction.fetch(int(tid))
return cls(app, row) if row is not None else None
@classmethod
async def query(
cls,
app: web.Application,
transaction_id: Optional[int] = None,
user_id: Optional[int] = None,
reference: Optional[str] = None,
created_before: Optional[str] = None,
created_after: Optional[str] = None,
) -> List[Self]:
data = app[datamodelsv]
TXN = data.Transaction
conds = []
if transaction_id is not None:
conds.append(TXN.transaction_id == int(transaction_id))
if user_id is not None:
conds.append(TXN.user_id == int(user_id))
if reference is not None:
conds.append(TXN.reference == reference)
if created_before is not None:
cbefore = datetime.fromisoformat(created_before)
conds.append(TXN.created_at <= cbefore)
if created_after is not None:
cafter = datetime.fromisoformat(created_after)
conds.append(TXN.created_at >= cafter)
rows = await TXN.fetch_where(*conds).order_by(TXN.created_at)
return [cls(app, row) for row in rows]
@classmethod
async def create(cls, app: web.Application, **kwargs: Unpack[TransactionCreateParams]) -> Self:
data = app[datamodelsv]
create_args = {}
for key in ('user_id', 'description', 'amount', 'reference'):
create_args[key] = kwargs.get(key)
logger.info(f"Creating Transaction with {create_args=}")
row = await data.Transaction.create(**create_args)
return cls(app, row)
async def edit(self):
raise ValueError("Transactions are immutable.")
async def delete(self):
raise ValueError("Transactions cannot be deleted directly.")
async def get_user(self):
from .users import User
return await User.fetch_from_id(self.app, self.row.user_id)
async def prepare(self) -> TransactionPayload:
user = await self.get_user()
if user is None:
raise ValueError("Transaction owner does not exist! This cannot happen.")
results: TransactionPayload = {
'transaction_id': self.row.transaction_id,
'user_id': self.row.user_id,
'user': await user.prepare(details=False),
'amount': self.row.amount,
'description': self.row.description,
'reference': self.row.reference,
'created_at': self.row.created_at.isoformat(),
}
return results
@routes.view('/transactions')
@routes.view('/transactions/', name='transactions')
class TransactionsView(web.View):
async def post(self):
request = self.request
params = await request.json()
for key in create_fields:
if key in request:
params.setdefault(key, request[key])
await Transaction.validate_create_params(params)
logger.info(f"Creating a new Transaction with args: {params=}")
txn = await Transaction.create(self.request.app, **params)
logger.debug(f"Created transaction: {txn!r}")
payload = await txn.prepare()
return web.json_response(payload)
async def get(self):
request = self.request
filter_params = {}
keys = [
'transaction_id', 'user_id', 'reference',
'created_before', 'created_after',
]
for key in keys:
value = request.query.get(key, request.get(key, None))
filter_params[key] = value
logger.info(f"Querying transactions with params: {filter_params=}")
txns = await Transaction.query(request.app, **filter_params)
payload = [await txn.prepare() for txn in txns]
return web.json_response(payload)
@routes.view('/transactions/{transaction_id}')
@routes.view('/transactions/{transaction_id}/', name='transaction_id')
class TransactionView(web.View):
async def resolve_transaction(self):
request = self.request
txn_id = request.match_info['transaction_id']
txn = await Transaction.fetch_from_id(request.app, int(txn_id))
if txn is None:
raise web.HTTPNotFound(text="No transaction exists with the given ID.")
return txn
async def get(self):
txn = await self.resolve_transaction()
logger.info(f"Received GET for transaction {txn=}")
payload = await txn.prepare()
return web.json_response(payload)
async def patch(self):
raise web.HTTPBadRequest(text="Transactions are immutable and cannot be edited.")
async def delete(self):
raise web.HTTPBadRequest(text="Transactions cannot be individually deleted.")
@routes.route('*', "/transactions/{transaction_id}/user")
@routes.route('*', "/transactions/{transaction_id}/user{tail:/.*}")
async def transaction_user_route(request: web.Request):
txn_id = int(request.match_info['transaction_id'])
txn = await Transaction.fetch_from_id(request.app, txn_id)
if txn is None:
raise web.HTTPNotFound(text="No transaction exists with the given ID.")
tail = request.match_info.get('tail', '')
new_path = "/users/{user_id}".format(user_id=txn.row.user_id) + tail
logger.info(f"Redirecting {request=} to {new_path}")
new_request = request.clone(rel_url=new_path)
new_request['user_id'] = txn.row.user_id
match_info = await request.app.router.resolve(new_request)
new_request._match_info = match_info
match_info.current_app = request.app
if match_info.handler:
return await match_info.handler(new_request)
else:
logger.info(f"Could not find handler matching {new_request}")
raise web.HTTPNotFound()

461
src/routes/users.py Normal file
View File

@@ -0,0 +1,461 @@
"""
- `/users` with `POST`, `GET`, `PATCH`, `DELETE`
- `/users/{user_id}` with `GET`, `PATCH`, `DELETE`
- `/users/{user_id}/events` which is passed to `/events`
- `/users/{user_id}/specimen` which is passed to `/specimens/{specimen_id}`
- `/users/{user_id}/specimens` which is passed to `/specimens`
- `/users/{user_id}/wallet` with `GET`
- `/users/{user_id}/transactions` which is passed to `/transactions`
"""
import logging
from datetime import datetime
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List
from aiohttp import web
import discord
from data import Condition, condition
from data.conditions import NULL
from data.queries import JOINTYPE
from datamodels import DataModel
from modules.profiles.data import ProfileData
from utils.lib import MessageArgs, tabulate
from .lib import ModelField, datamodelsv, dbvar, event_log, profiledatav
from .specimens import Specimen, SpecimenPayload
routes = web.RouteTableDef()
logger = logging.getLogger(__name__)
class UserPayload(TypedDict):
user_id: int
twitch_id: Optional[str]
name: Optional[str]
preferences: Optional[str]
created_at: str
class UserDetailsPayload(UserPayload):
specimen: Optional[SpecimenPayload]
inventory: List # TODO
wallet: int
class UserCreateParamsReq(TypedDict, total=True):
twitch_id: str
name: str
class UserCreateParams(UserCreateParamsReq, total=False):
preferences: str
class UserEditParams(TypedDict, total=False):
name: Optional[str]
preferences: Optional[str]
fields = [
ModelField('user_id', int, False, False, False),
ModelField('twitch_id', str, True, True, False),
ModelField('name', str, True, True, True),
ModelField('preferences', str, False, True, True),
ModelField('created_at', str, False, False, False),
]
req_fields = {field.name for field in fields if field.required}
edit_fields = {field.name for field in fields if field.can_edit}
create_fields = {field.name for field in fields if field.can_create}
class User:
def __init__(self, app: web.Application, row: DataModel.Dreamer):
self.app = app
self.data = app[datamodelsv]
self.profile_data = app[profiledatav]
self.row = row
self._pref_row: Optional[DataModel.UserPreferences] = None
async def get_prefs(self) -> DataModel.UserPreferences:
if self._pref_row is None:
self._pref_row = await self.data.UserPreferences.fetch_or_create(self.row.user_id)
return self._pref_row
@classmethod
async def validate_create_params(cls, params):
if extra := next((key for key in params if key not in create_fields), None):
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to user creation.")
if missing := next((key for key in req_fields if key not in params), None):
raise web.HTTPBadRequest(text=f"User params missing required key '{missing}'")
@classmethod
async def fetch_from_id(cls, app: web.Application, user_id: int):
data = app[datamodelsv]
row = await data.Dreamer.fetch(int(user_id))
return cls(app, row) if row is not None else None
@classmethod
async def query(
cls,
app: web.Application,
user_id: Optional[str] = None,
twitch_id: Optional[str] = None,
name: Optional[str] = None,
created_before: Optional[str] = None,
created_after: Optional[str] = None,
) -> List[Self]:
data = app[datamodelsv]
Dreamer = data.Dreamer
conds = []
if user_id is not None:
conds.append(Dreamer.user_id == int(user_id))
if twitch_id is not None:
conds.append(Dreamer.twitch_id == twitch_id)
if name is not None:
conds.append(Dreamer.name == name)
if created_before is not None:
cbefore = datetime.fromisoformat(created_before)
conds.append(Dreamer.created_at <= cbefore)
if created_after is not None:
cafter = datetime.fromisoformat(created_after)
conds.append(Dreamer.created_at >= cafter)
rows = await Dreamer.fetch_where(*conds).order_by(Dreamer.created_at)
return [cls(app, row) for row in rows]
@classmethod
async def create(cls, app: web.Application, **kwargs: Unpack[UserCreateParams]):
"""
Create a new User from the provided data.
This creates the associated UserProfile, TwitchProfile, and UserPreferences if needed.
If a profile already exists, this does *not* error.
Instead, this updates the existing User with the new data.
"""
data = app[datamodelsv]
twitch_id = kwargs['twitch_id']
name = kwargs['name']
prefs = kwargs.get('preferences')
# Quick sanity check on the twitch id
if not twitch_id or not twitch_id.isdigit():
raise web.HTTPBadRequest(text="Invalid 'twitch_id' passed to user creation!")
# First check if the profile already exists by querying the Dreamer database
edited = 0 # 0 means not edited, 1 means created, 2 means modified
rows = await data.Dreamer.fetch_where(twitch_id=twitch_id)
if rows:
logger.debug(f"Updating Dreamer for {twitch_id=} with {kwargs}")
dreamer = rows[0]
# A twitch profile with this twitch_id already exists
# But it is possible UserPreferences don't exist
if dreamer.preferences is None and dreamer.name is None:
await data.UserPreferences.fetch_or_create(dreamer.user_id, twitch_name=name, preferences=prefs)
dreamer = await dreamer.refresh()
edited = 2
# Now compare the existing data against the provided data and update if needed
if name != dreamer.name:
q = data.UserPreferences.table.update_where(profileid=dreamer.user_id)
q.set(twitch_name=name)
if prefs is not None:
q.set(preferences=prefs)
await q
dreamer = await dreamer.refresh()
edited = 2
else:
# Create from scratch
logger.info(f"Creating Dreamer for {twitch_id=} with {kwargs}")
# TODO: Should be in a transaction.. actually let's add transactions to the middleware..
profile_data = app[profiledatav]
user_profile = await profile_data.UserProfileRow.create(nickname=name)
await profile_data.TwitchProfileRow.create(
profileid=user_profile.profileid,
userid=twitch_id,
)
await data.UserPreferences.create(
profileid=user_profile.profileid,
twitch_name=name,
preferences=prefs
)
dreamer = await data.Dreamer.fetch(user_profile.profileid)
assert dreamer is not None
edited = 1
self = cls(app, dreamer)
if edited == 1:
args = await self.event_log_args(title=f"User #{dreamer.user_id} created!")
await event_log(**args.send_args)
elif edited == 2:
args = await self.event_log_args(title=f"User #{dreamer.user_id} updated!")
await event_log(**args.send_args)
return self
async def edit(self, **kwargs: Unpack[UserEditParams]):
data = self.data
# We can edit the name, and preferences
prefs = await self.get_prefs()
update_args = {}
if 'name' in kwargs:
update_args['twitch_name'] = kwargs['name']
if 'preferences' in kwargs:
update_args['preferences'] = kwargs['preferences']
if update_args:
logger.info(f"Updating dreamer {self.row=} with {kwargs}")
await prefs.update(**update_args)
args = await self.event_log_args(title=f"User #{self.row.user_id} updated!")
await event_log(**args.send_args)
async def delete(self) -> UserDetailsPayload:
payload = await self.prepare(details=True)
# This will cascade to all other data the user has
await self.profile_data.UserProfileRow.table.delete_where(profileid=self.row.user_id)
# Make sure we take the user out of cache
await self.row.refresh()
return payload
async def get_wallet(self):
query = self.data.Transaction.table.select_where(user_id=self.row.user_id)
query.select(wallet="SUM(amount)")
query.with_no_adapter()
results = await query
return results[0]['wallet']
async def get_specimen(self) -> Optional[Specimen]:
data = self.data
active_specrows = await data.Specimen.fetch_where(
owner_id=self.row.user_id,
forgotten_at=NULL
)
if active_specrows:
row = active_specrows[0]
spec = Specimen(self.app, row)
else:
spec = None
return spec
async def get_inventory(self):
return []
async def event_log_args(self, **kwargs) -> MessageArgs:
desc = '\n'.join(await self.tabulate())
embed = discord.Embed(description=desc, timestamp=self.row.created_at, **kwargs)
embed.set_footer(text='Created At')
# TODO: We could add wallet, specimen, and inventory info here too
return MessageArgs(embed=embed)
async def tabulate(self):
"""
Present the User as a discord-readable table.
"""
table = {
'user_id': f"`{self.row.user_id}`",
'twitch_id': f"`{self.row.twitch_id}`" if self.row.twitch_id else 'No Twitch linked',
'name': f"`{self.row.name}`",
'preferences': f"`{self.row.preferences}`",
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
}
return tabulate(*table.items())
@overload
async def prepare(self, details: Literal[True]=True) -> UserDetailsPayload:
...
@overload
async def prepare(self, details: Literal[False]=False) -> UserPayload:
...
async def prepare(self, details=False) -> UserPayload | UserDetailsPayload:
# Since we are working with view rows, make sure we refresh
row = self.row
await row.refresh()
base_user: UserPayload = {
'user_id': row.user_id,
'twitch_id': str(row.twitch_id) if row.twitch_id else None,
'name': row.name,
'preferences': row.preferences,
'created_at': row.created_at.isoformat(),
}
if details:
# Now add details
specimen = await self.get_specimen()
sp_payload = await specimen.prepare() if specimen is not None else None
inventory = [await item.prepare() for item in await self.get_inventory()]
user: UserPayload = base_user | {
'specimen': sp_payload,
'inventory': inventory,
'wallet': await self.get_wallet(),
}
else:
user = base_user
logger.debug(f"User prepared: {user}")
return user
@routes.view('/users')
@routes.view('/users/', name='users')
class UsersView(web.View):
async def post(self):
request = self.request
params = await request.json()
for key in create_fields:
if key in request:
params.setdefault(key, request[key])
await User.validate_create_params(params)
logger.info(f"Creating a new user with args: {params=}")
user = await User.create(self.request.app, **params)
logger.debug(f"Created user: {user!r}")
payload = await user.prepare(details=True)
return web.json_response(payload)
async def get(self):
request = self.request
filter_params = {}
keys = [
'user_id', 'twitch_id', 'name', 'created_before', 'created_after',
]
for key in keys:
value = request.query.get(key, request.get(key, None))
filter_params[key] = value
logger.info(f"Querying users with params: {filter_params=}")
users = await User.query(request.app, **filter_params)
payload = [await user.prepare(details=True) for user in users]
return web.json_response(payload)
@routes.view('/users/{user_id}')
@routes.view('/users/{user_id}/', name='user')
class UserView(web.View):
async def resolve_user(self):
request = self.request
user_id = request.match_info['user_id']
user = await User.fetch_from_id(request.app, int(user_id))
if user is None:
raise web.HTTPNotFound(text="No user exists with the given ID.")
return user
async def get(self):
user = await self.resolve_user()
logger.info(f"Received GET for user {user=}")
payload = await user.prepare(details=True)
return web.json_response(payload)
async def patch(self):
user = await self.resolve_user()
params = await self.request.json()
edit_data = {}
for key, value in params.items():
if key not in edit_fields:
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of User!")
edit_data[key] = value
for key in edit_fields:
if key in self.request:
edit_data.setdefault(key, self.request[key])
logger.info(f"Received PATCH for user {user} with params: {params}")
await user.edit(**edit_data)
payload = await user.prepare(details=True)
return web.json_response(payload)
async def delete(self):
user = await self.resolve_user()
logger.info(f"Received DELETE for user {user}")
payload = await user.delete()
return web.json_response(payload)
@routes.route('*', "/users/{user_id}{tail:/events}")
@routes.route('*', "/users/{user_id}{tail:/events/.*}")
@routes.route('*', "/users/{user_id}{tail:/transactions}")
@routes.route('*', "/users/{user_id}{tail:/transactions/.*}")
@routes.route('*', "/users/{user_id}{tail:/specimens}")
@routes.route('*', "/users/{user_id}{tail:/specimens/.*}")
async def user_prefix_routes(request: web.Request):
user_id = int(request.match_info['user_id'])
user = await User.fetch_from_id(request.app, user_id)
if user is None:
raise web.HTTPNotFound(text="No user exists with the given ID.")
new_path = request.match_info['tail']
logger.info(f"Redirecting {request=} to {new_path=} and setting {user_id=}")
new_request = request.clone(rel_url=new_path)
new_request['user_id'] = user_id
match_info = await request.app.router.resolve(new_request)
new_request._match_info = match_info
match_info.current_app = request.app
if match_info.handler:
return await match_info.handler(new_request)
else:
logger.info(f"Could not find handler matching {new_request}")
raise web.HTTPNotFound()
@routes.route('*', "/users/{user_id}/specimen")
@routes.route('*', "/users/{user_id}/specimen{tail:/.*}")
async def user_specimen_route(request: web.Request):
user_id = int(request.match_info['user_id'])
user = await User.fetch_from_id(request.app, user_id)
if user is None:
raise web.HTTPNotFound(text="No user exists with the given ID.")
tail = request.match_info.get('tail', '')
specimen = await user.get_specimen()
if request.method == 'POST' and not tail.strip('/'):
if specimen is None:
# Redirect to POST /specimens
# TODO: Would be nicer to use named handler here
new_path = '/specimens'
logger.info(f"Redirecting {request=} to POST /specimens")
new_request = request.clone(rel_url=new_path)
new_request['user_id'] = user_id
new_request['owner_id'] = user_id
match_info = await request.app.router.resolve(new_request)
new_request._match_info = match_info
match_info.current_app = request.app
return await match_info.handler(new_request)
else:
raise web.HTTPBadRequest(text="This user already has an active specimen!")
elif specimen is None:
raise web.HTTPNotFound(text="This user has no active specimen.")
else:
specimen_id = specimen.row.specimen_id
# Redirect to POST /specimens/{specimen_id}/...
new_path = f"/specimens/{specimen_id}".format(specimen_id=specimen_id) + tail
logger.info(f"Redirecting {request=} to {new_path}")
new_request = request.clone(rel_url=new_path)
new_request['user_id'] = user_id
new_request['owner_id'] = user_id
new_request['specimen_id'] = specimen_id
match_info = await request.app.router.resolve(new_request)
new_request._match_info = match_info
match_info.current_app = request.app
if match_info.handler:
return await match_info.handler(new_request)
else:
logger.info(f"Could not find handler matching {new_request}")
raise web.HTTPNotFound()
@routes.route('GET', "/users/{user_id}/wallet")
@routes.route('GET', "/users/{user_id}/wallet/")
async def user_wallet_route(request: web.Request):
user_id = int(request.match_info['user_id'])
user = await User.fetch_from_id(request.app, user_id)
if user is None:
raise web.HTTPNotFound(text="No user exists with the given ID.")
wallet = await user.get_wallet()
return web.json_response(wallet)

9
src/twitch/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
import logging
logger = logging.getLogger(__name__)
from .cog import TwitchAuthCog
async def setup(bot):
await bot.add_cog(TwitchAuthCog(bot))

50
src/twitch/authclient.py Normal file
View File

@@ -0,0 +1,50 @@
"""
Testing client for the twitch AuthServer.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
import asyncio
import aiohttp
from twitchAPI.twitch import Twitch
from twitchAPI.oauth import UserAuthenticator
from twitchAPI.type import AuthScope
from meta.config import conf
URI = "http://localhost:3000/twiauth/confirm"
TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ]
async def main():
# Load in client id and secret
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI)
url = auth.return_auth_url()
# Post url to user
print(url)
# Send listen request to server
# Wait for listen request
async with aiohttp.ClientSession() as session:
async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws:
await ws.send_json({'state': auth.state})
result = await ws.receive_json()
# Hopefully get back code, print the response
print(f"Recieved: {result}")
# Authorise with code and client details
tokens = await auth.authenticate(user_token=result['code'])
if tokens:
token, refresh = tokens
await twitch.set_user_authentication(token, TARGET_SCOPE, refresh)
print(f"Authorised!")
if __name__ == '__main__':
asyncio.run(main())

86
src/twitch/authserver.py Normal file
View File

@@ -0,0 +1,86 @@
import logging
import uuid
import asyncio
from contextvars import ContextVar
import aiohttp
from aiohttp import web
logger = logging.getLogger(__name__)
reqid: ContextVar[str] = ContextVar('reqid', default='ROOT')
class AuthServer:
def __init__(self):
self.listeners = {}
async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse:
args = request.query
if 'state' not in args:
raise web.HTTPBadRequest(text="No state provided.")
if args['state'] not in self.listeners:
raise web.HTTPBadRequest(text="Invalid state.")
self.listeners[args['state']].set_result(dict(args))
return web.Response(text="Authorisation complete! You may now close this page and return to the application.")
async def handle_listen_request(self, request: web.Request) -> web.StreamResponse:
_reqid = str(uuid.uuid1())
reqid.set(_reqid)
logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}")
ws = web.WebSocketResponse()
await ws.prepare(request)
# Get the listen request data
try:
listen_req = await ws.receive_json(timeout=60)
logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}")
if 'state' not in listen_req:
logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.")
raise web.HTTPBadRequest(text="Listen request must include state string.")
elif listen_req['state'] in self.listeners:
logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.")
raise web.HTTPBadRequest(text="Invalid state string.")
except ValueError:
logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.")
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
except TypeError:
logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.")
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
except asyncio.TimeoutError:
logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.")
raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.")
except Exception:
logger.exception(f"[reqid: {_reqid}] Unknown exception.")
raise web.HTTPInternalServerError()
try:
fut = self.listeners[listen_req['state']] = asyncio.Future()
result = await asyncio.wait_for(fut, timeout=120)
except asyncio.TimeoutError:
logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.")
raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.")
finally:
self.listeners.pop(listen_req['state'], None)
logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.")
await ws.send_json(result)
await ws.close()
logger.debug(f"[reqid: {_reqid}] Request completed handling.")
return ws
def main(argv):
app = web.Application()
server = AuthServer()
app.router.add_get("/twiauth/confirm", server.handle_twitch_callback)
app.router.add_get("/twiauth/listen", server.handle_listen_request)
logger.info("App setup and configured. Starting now.")
web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080)
if __name__ == '__main__':
import sys
main(sys.argv)

113
src/twitch/cog.py Normal file
View File

@@ -0,0 +1,113 @@
import asyncio
from enum import Enum
from typing import Optional
from datetime import timedelta
import discord
from discord.ext import commands as cmds
from twitchAPI.oauth import UserAuthenticator
from twitchAPI.twitch import Twitch
from twitchAPI.type import AuthScope
from twitchio.ext import commands
from data.queries import ORDER
from meta import LionCog, LionBot, CrocBot
from meta.LionContext import LionContext
from twitch.userflow import UserAuthFlow
from utils.lib import utc_now
from . import logger
from .data import TwitchAuthData
class TwitchAuthCog(LionCog):
DEFAULT_SCOPES = []
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(TwitchAuthData())
self.client_cache = {}
async def cog_load(self):
await self.data.init()
# ----- Auth API -----
async def fetch_client_for(self, userid: str):
authrow = await self.data.UserAuthRow.fetch(userid)
if authrow is None:
# TODO: Some user authentication error
self.client_cache.pop(userid, None)
raise ValueError("Requested user is not authenticated.")
if (twitch := self.client_cache.get(userid)) is None:
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
scopes = await self.data.UserAuthRow.get_scopes_for(userid)
authscopes = [AuthScope(scope) for scope in scopes]
await twitch.set_user_authentication(authrow.access_token, authscopes, authrow.refresh_token)
self.client_cache[userid] = twitch
return twitch
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
"""
Checks whether the given userid is authorised.
If 'scopes' is given, will also check the user has all of the given scopes.
"""
authrow = await self.data.UserAuthRow.fetch(userid)
if authrow:
if scopes:
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
desired = {scope.value for scope in scopes}
has_auth = desired.issubset(has_scopes)
logger.info(f"Auth check for `{userid}`: Requested scopes {desired}, has scopes {has_scopes}. Passed: {has_auth}")
else:
has_auth = True
else:
has_auth = False
return has_auth
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
"""
Start the user authentication flow for the given userid.
Will request the given scopes along with the default ones and any existing scopes.
"""
self.client_cache.pop(userid, None)
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
existing = map(AuthScope, existing_strs)
to_request = set(existing).union(scopes)
return await self.start_auth(to_request)
async def start_auth(self, scopes = []):
# TODO: Work out a way to just clone the current twitch object
# Or can we otherwise build UserAuthenticator without app auth?
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
await flow.setup()
return flow
# ----- Commands -----
@cmds.hybrid_command(name='auth')
async def cmd_auth(self, ctx: LionContext):
if ctx.interaction:
await ctx.interaction.response.defer(ephemeral=True)
flow = await self.start_auth()
await ctx.reply(flow.auth.return_auth_url())
await flow.run()
await ctx.reply("Authentication Complete!")
@cmds.hybrid_command(name='modauth')
async def cmd_modauth(self, ctx: LionContext):
if ctx.interaction:
await ctx.interaction.response.defer(ephemeral=True)
scopes = [
AuthScope.MODERATOR_READ_FOLLOWERS,
AuthScope.CHANNEL_READ_REDEMPTIONS,
AuthScope.MODERATOR_MANAGE_CHAT_MESSAGES,
]
flow = await self.start_auth(scopes=scopes)
await ctx.reply(flow.auth.return_auth_url())
await flow.run()
await ctx.reply("Authentication Complete!")

79
src/twitch/data.py Normal file
View File

@@ -0,0 +1,79 @@
import datetime as dt
from data import Registry, RowModel, Table
from data.columns import Integer, String, Timestamp
class TwitchAuthData(Registry):
class UserAuthRow(RowModel):
"""
Schema
------
CREATE TABLE twitch_user_auth(
userid TEXT PRIMARY KEY,
access_token TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
refresh_token TEXT NOT NULL,
obtained_at TIMESTAMPTZ
);
"""
_tablename_ = 'twitch_user_auth'
_cache_ = {}
userid = Integer(primary=True)
access_token = String()
refresh_token = String()
expires_at = Timestamp()
obtained_at = Timestamp()
@classmethod
async def update_user_auth(
cls, userid: str, token: str, refresh: str,
expires_at: dt.datetime, obtained_at: dt.datetime,
scopes: list[str]
):
if cls._connector is None:
raise ValueError("Attempting to use uninitialised Registry.")
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
# Clear row for this userid
await cls.table.delete_where(userid=userid)
# Insert new user row
row = await cls.create(
userid=userid,
access_token=token,
refresh_token=refresh,
expires_at=expires_at,
obtained_at=obtained_at
)
# Insert new scope rows
if scopes:
await TwitchAuthData.user_scopes.insert_many(
('userid', 'scope'),
*((userid, scope) for scope in scopes)
)
return row
@classmethod
async def get_scopes_for(cls, userid: str) -> list[str]:
"""
Get a list of scopes stored for the given user.
Will return an empty list if the user is not authenticated.
"""
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
return [row['scope'] for row in rows] if rows else []
"""
Schema
------
CREATE TABLE twitch_user_scopes(
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
scope TEXT
);
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
"""
user_scopes = Table('twitch_user_scopes')

0
src/twitch/lib.py Normal file
View File

88
src/twitch/userflow.py Normal file
View File

@@ -0,0 +1,88 @@
from typing import Optional
import datetime as dt
from aiohttp import web
import aiohttp
from twitchAPI.twitch import Twitch
from twitchAPI.oauth import UserAuthenticator, validate_token
from twitchAPI.type import AuthType
from twitchio.client import asyncio
from meta.errors import SafeCancellation
from utils.lib import utc_now
from .data import TwitchAuthData
from . import logger
class UserAuthFlow:
auth: UserAuthenticator
data: TwitchAuthData
auth_ws: str
def __init__(self, data, auth, auth_ws):
self.auth = auth
self.data = data
self.auth_ws = auth_ws
self._setup_done = asyncio.Event()
self._comm_task: Optional[asyncio.Task] = None
async def setup(self):
"""
Establishes websocket connection to the AuthServer,
and requests listening for the given state.
Propagates any exceptions that occur during connection setup.
"""
if self._setup_done.is_set():
raise ValueError("UserAuthFlow is already set up.")
self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate')
await self._setup_done.wait()
if self._comm_task.done() and (exc := self._comm_task.exception()):
raise exc
async def _communicate(self):
async with aiohttp.ClientSession() as session:
async with session.ws_connect(self.auth_ws) as ws:
await ws.send_json({'state': self.auth.state})
self._setup_done.set()
return await ws.receive_json()
async def run(self) -> TwitchAuthData.UserAuthRow:
if not self._setup_done.is_set():
raise ValueError("Cannot run UserAuthFlow before setup.")
if self._comm_task is None:
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
result = await self._comm_task
if result.get('error', None):
# TODO Custom auth errors
# This is only documented to occur when the user denies the auth
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
if result.get('state', None) != self.auth.state:
# This should never happen unless the authserver has its wires crossed somehow,
# or the connection has been tampered with.
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
logger.critical(
f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG."
)
raise SafeCancellation(
"Could not complete authentication! Invalid server response."
)
# Now assume result has a valid code
# Exchange code for an auth token and a refresh token
# Ignore type here, authenticate returns None if a callback function has been given.
token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore
# Fetch the associated userid and basic info
v_result = await validate_token(token)
userid = v_result['user_id']
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
# Save auth data
return await self.data.UserAuthRow.update_user_auth(
userid=userid, token=token, refresh=refresh,
expires_at=expiry, obtained_at=utc_now(),
scopes=[scope.value for scope in self.auth.scopes]
)

27
src/utils/auth.py Normal file
View File

@@ -0,0 +1,27 @@
from typing import Awaitable, Callable
from aiohttp import hdrs, web
CHALLENGE = web.Response(
body="<b> 401 UNAUTHORIZED </b>",
status=401,
reason='UNAUTHORIZED',
headers={
hdrs.WWW_AUTHENTICATE: 'X-API-KEY',
hdrs.CONTENT_TYPE: 'text/html; charset=utf-8',
hdrs.CONNECTION: 'close',
},
)
def key_auth_factory(required_token: str):
"""
Creates an aiohttp middleware that ensures
the `required_token` is provided in the `X-API-KEY` header.
"""
@web.middleware
async def key_auth_middleware(request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]]) -> web.StreamResponse:
auth_header = request.headers.get('X-API-KEY')
if not auth_header or auth_header.strip() != required_token:
return CHALLENGE
else:
return await handler(request)
return key_auth_middleware