Compare commits
25 Commits
9625dec1e4
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 445eccccd6 | |||
| 6ec500ec87 | |||
| 8e2bd67efc | |||
| 092e818990 | |||
| 2bf95beaae | |||
| 250b55634d | |||
| 9e5c2f5777 | |||
| e1a1f7d4fe | |||
| c3ed48e918 | |||
| 48a01a2861 | |||
| d83709d2c2 | |||
| 7f977f90e8 | |||
| 04b6dcbc3f | |||
| c07577cc0a | |||
| dc551b34a9 | |||
| 94bc8b6c21 | |||
| aba73b8bba | |||
| 77dc90cc32 | |||
| a02cc0977a | |||
| 5efcdd6709 | |||
| 0adccaae02 | |||
| a7afa5001d | |||
| d271248812 | |||
| 8421c5359d | |||
| 2cf81c38e8 |
@@ -189,9 +189,10 @@ CREATE TABLE stamp_types (
|
||||
|
||||
CREATE TABLE documents (
|
||||
document_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
document_data VARCHAR NOT NULL,
|
||||
document_data TEXT NOT NULL,
|
||||
seal INTEGER NOT NULL,
|
||||
metadata TEXT
|
||||
metadata TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE document_stamps (
|
||||
@@ -230,14 +231,14 @@ CREATE TABLE plain_events (
|
||||
event_id integer PRIMARY KEY,
|
||||
event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'),
|
||||
message TEXT NOT NULL,
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE raid_events (
|
||||
event_id integer PRIMARY KEY,
|
||||
event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'),
|
||||
visitor_count INTEGER NOT NULL,
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE cheer_events (
|
||||
@@ -246,7 +247,7 @@ CREATE TABLE cheer_events (
|
||||
amount INTEGER NOT NULL,
|
||||
cheer_type TEXT,
|
||||
message TEXT,
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE subscriber_events (
|
||||
@@ -255,10 +256,37 @@ CREATE TABLE subscriber_events (
|
||||
subscribed_length INTEGER NOT NULL,
|
||||
tier INTEGER NOT NULL,
|
||||
message TEXT,
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
|
||||
CREATE VIEW event_details AS
|
||||
SELECT
|
||||
events.event_id AS event_id,
|
||||
events.user_id AS user_id,
|
||||
events.document_id AS document_id,
|
||||
events.user_name AS user_name,
|
||||
events.event_type AS event_type,
|
||||
events.occurred_at AS occurred_at,
|
||||
events.created_at AS created_at,
|
||||
plain_events.message AS plain_message,
|
||||
raid_events.visitor_count AS raid_visitor_count,
|
||||
cheer_events.amount AS cheer_amount,
|
||||
cheer_events.cheer_type AS cheer_type,
|
||||
cheer_events.message AS cheer_message,
|
||||
subscriber_events.subscribed_length AS subscriber_length,
|
||||
subscriber_events.tier AS subscriber_tier,
|
||||
subscriber_events.message AS subscriber_message,
|
||||
documents.seal AS document_seal
|
||||
FROM
|
||||
events
|
||||
LEFT JOIN plain_events USING (event_id)
|
||||
LEFT JOIN raid_events USING (event_id)
|
||||
LEFT JOIN cheer_events USING (event_id)
|
||||
LEFT JOIN subscriber_events USING (event_id)
|
||||
LEFT JOIN documents USING (document_id)
|
||||
ORDER BY events.occurred_at ASC;
|
||||
|
||||
-- }}}
|
||||
|
||||
-- Specimens {{{
|
||||
@@ -269,6 +297,7 @@ CREATE TABLE user_specimens (
|
||||
born_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
forgotten_at TIMESTAMPTZ
|
||||
);
|
||||
CREATE UNIQUE INDEX ON user_specimens (owner_id) WHERE forgotten_at IS NULL;
|
||||
|
||||
-- }}}
|
||||
|
||||
|
||||
@@ -4,3 +4,5 @@ discord.py [voice]
|
||||
iso8601
|
||||
psycopg[pool]
|
||||
pytz
|
||||
twitchio
|
||||
twitchAPI
|
||||
|
||||
61
src/api.py
Normal file
61
src/api.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
from aiohttp import web
|
||||
import logging
|
||||
|
||||
from meta import conf
|
||||
from data import Database
|
||||
from utils.auth import key_auth_factory
|
||||
from datamodels import DataModel
|
||||
from constants import DATA_VERSION
|
||||
|
||||
from modules.profiles.data import ProfileData
|
||||
from routes import dbvar, datamodelsv, profiledatav, register_routes
|
||||
|
||||
|
||||
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def attach_db(app: web.Application):
|
||||
db = Database(conf.data['args'])
|
||||
async with db.open():
|
||||
version = await db.version()
|
||||
if version.version != DATA_VERSION:
|
||||
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
|
||||
logger.critical(error)
|
||||
raise RuntimeError(error)
|
||||
|
||||
datamodel = DataModel()
|
||||
db.load_registry(datamodel)
|
||||
await datamodel.init()
|
||||
|
||||
profiledata = ProfileData()
|
||||
db.load_registry(profiledata)
|
||||
await profiledata.init()
|
||||
|
||||
app[dbvar] = db
|
||||
app[datamodelsv] = datamodel
|
||||
app[profiledatav] = profiledata
|
||||
|
||||
yield
|
||||
|
||||
|
||||
async def test(request: web.Request) -> web.Response:
|
||||
return web.Response(text="Welcome to the Dreamspace API. Please donate an important childhood memory to continue.")
|
||||
|
||||
def app_factory():
|
||||
auth = key_auth_factory(conf.API['TOKEN'])
|
||||
app = web.Application(middlewares=[auth])
|
||||
app.cleanup_ctx.append(attach_db)
|
||||
app.router.add_get('/', test)
|
||||
register_routes(app.router)
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = app_factory()
|
||||
web.run_app(app, port=int(conf.API['PORT']))
|
||||
|
||||
11
src/bot.py
11
src/bot.py
@@ -4,6 +4,7 @@ import logging
|
||||
import aiohttp
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from twitchAPI.twitch import Twitch
|
||||
|
||||
from meta import LionBot, conf, sharding, appname
|
||||
from meta.app import shardname
|
||||
@@ -49,13 +50,15 @@ async def _data_monitor() -> ComponentStatus:
|
||||
|
||||
async def main():
|
||||
log_action_stack.set(("Initialising",))
|
||||
logger.info("Initialising StudyLion")
|
||||
logger.info("Initialising LionBot")
|
||||
|
||||
intents = discord.Intents.all()
|
||||
intents.members = True
|
||||
intents.message_content = True
|
||||
intents.presences = False
|
||||
|
||||
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
|
||||
|
||||
async with db.open():
|
||||
version = await db.version()
|
||||
if version.version != DATA_VERSION:
|
||||
@@ -73,6 +76,7 @@ async def main():
|
||||
config=conf,
|
||||
initial_extensions=[
|
||||
'core',
|
||||
'twitch',
|
||||
'modules',
|
||||
],
|
||||
web_client=session,
|
||||
@@ -82,6 +86,7 @@ async def main():
|
||||
help_command=None,
|
||||
proxy=conf.bot.get('proxy', None),
|
||||
chunk_guilds_at_startup=False,
|
||||
twitch=twitch
|
||||
) as lionbot:
|
||||
ctx_bot.set(lionbot)
|
||||
lionbot.system_monitor.add_component(
|
||||
@@ -89,11 +94,11 @@ async def main():
|
||||
)
|
||||
try:
|
||||
log_context.set(f"APP: {appname}")
|
||||
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
|
||||
logger.info("LionBot initialised, starting!", extra={'action': 'Starting'})
|
||||
await lionbot.start(conf.bot['TOKEN'])
|
||||
except asyncio.CancelledError:
|
||||
log_context.set(f"APP: {appname}")
|
||||
logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
|
||||
logger.info("LionBot closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
|
||||
|
||||
|
||||
def _main():
|
||||
|
||||
6
src/brand.py
Normal file
6
src/brand.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import discord
|
||||
|
||||
|
||||
# Theme
|
||||
MAIN_COLOUR = discord.Colour.from_str('#11EA11')
|
||||
ACCENT_COLOUR = discord.Colour.from_str('#EA11EA')
|
||||
@@ -11,6 +11,7 @@ from meta.app import shardname, appname
|
||||
from meta.logger import log_wrap
|
||||
from utils.lib import utc_now
|
||||
|
||||
from datamodels import DataModel
|
||||
from .data import CoreData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -29,7 +30,9 @@ class CoreCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = CoreData()
|
||||
self.datamodel = DataModel()
|
||||
bot.db.load_registry(self.data)
|
||||
bot.db.load_registry(self.datamodel)
|
||||
|
||||
self.app_config: Optional[CoreData.AppConfig] = None
|
||||
self.bot_config: Optional[CoreData.BotConfig] = None
|
||||
@@ -43,6 +46,9 @@ class CoreCog(LionCog):
|
||||
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
|
||||
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
|
||||
|
||||
await self.data.init()
|
||||
await self.datamodel.init()
|
||||
|
||||
# Load the app command cache
|
||||
await self.reload_appcmd_cache()
|
||||
|
||||
|
||||
@@ -47,8 +47,8 @@ class Connector:
|
||||
return AsyncConnectionPool(
|
||||
self._conn_args,
|
||||
open=False,
|
||||
min_size=4,
|
||||
max_size=8,
|
||||
min_size=1,
|
||||
max_size=4,
|
||||
configure=self._setup_connection,
|
||||
kwargs=self._conn_kwargs
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Optional
|
||||
|
||||
from psycopg import AsyncCursor, sql
|
||||
from psycopg.abc import Query, Params
|
||||
from psycopg._encodings import pgconn_encoding
|
||||
from psycopg._encodings import conn_encoding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,7 +15,7 @@ class AsyncLoggingCursor(AsyncCursor):
|
||||
elif isinstance(query, (sql.SQL, sql.Composed)):
|
||||
msg = query.as_string(self)
|
||||
elif isinstance(query, bytes):
|
||||
msg = query.decode(pgconn_encoding(self._conn.pgconn), 'replace')
|
||||
msg = query.decode(conn_encoding(self._conn.pgconn), 'replace')
|
||||
else:
|
||||
msg = repr(query)
|
||||
return msg
|
||||
|
||||
337
src/datamodels.py
Normal file
337
src/datamodels.py
Normal file
@@ -0,0 +1,337 @@
|
||||
from io import BytesIO
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import NamedTuple
|
||||
|
||||
from data import Registry, RowModel, Table, RegisterEnum
|
||||
from data.columns import Integer, String, Timestamp, Column
|
||||
|
||||
|
||||
class EventType(Enum):
|
||||
SUBSCRIBER = 'subscriber',
|
||||
RAID = 'raid',
|
||||
CHEER = 'cheer',
|
||||
PLAIN = 'plain',
|
||||
|
||||
def info(self):
|
||||
if self is EventType.SUBSCRIBER:
|
||||
info = EventTypeInfo(
|
||||
EventType.SUBSCRIBER,
|
||||
DataModel.subscriber_events,
|
||||
("tier", "subscribed_length", "message"),
|
||||
("tier", "subscribed_length", "message"),
|
||||
('subscriber_tier', 'subscriber_length', 'subscriber_message'),
|
||||
)
|
||||
elif self is EventType.RAID:
|
||||
info = EventTypeInfo(
|
||||
EventType.RAID,
|
||||
DataModel.raid_events,
|
||||
('visitor_count',),
|
||||
('viewer_count',),
|
||||
('raid_visitor_count',),
|
||||
)
|
||||
elif self is EventType.CHEER:
|
||||
info = EventTypeInfo(
|
||||
EventType.CHEER,
|
||||
DataModel.cheer_events,
|
||||
('amount', 'cheer_type', 'message'),
|
||||
('amount', 'cheer_type', 'message'),
|
||||
('cheer_amount', 'cheer_type', 'cheer_message'),
|
||||
)
|
||||
elif self is EventType.PLAIN:
|
||||
info = EventTypeInfo(
|
||||
EventType.PLAIN,
|
||||
DataModel.plain_events,
|
||||
('message',),
|
||||
('message',),
|
||||
('plain_message',),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unexpected event type.")
|
||||
return info
|
||||
|
||||
|
||||
|
||||
class EventTypeInfo(NamedTuple):
|
||||
typ: EventType
|
||||
table: Table
|
||||
columns: tuple[str, ...]
|
||||
params: tuple[str, ...]
|
||||
detailcolumns: tuple[str, ...]
|
||||
|
||||
|
||||
class DataModel(Registry):
|
||||
_EventType = RegisterEnum(EventType, 'EventType')
|
||||
|
||||
class UserPreferences(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE user_preferences (
|
||||
profileid INTEGER PRIMARY KEY REFERENCES user_profiles (profileid) ON DELETE CASCADE,
|
||||
twitch_name TEXT,
|
||||
preferences TEXT
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'user_preferences'
|
||||
_cache_ = {}
|
||||
|
||||
profileid = Integer(primary=True)
|
||||
twitch_name = String()
|
||||
preferences = String()
|
||||
|
||||
class Dreamer(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE VIEW dreamers AS
|
||||
SELECT
|
||||
user_profiles.profileid AS user_id,
|
||||
user_preferences.twitch_name AS name,
|
||||
profiles_twitch.userid AS twitch_id,
|
||||
user_preferences.preferences AS preferences,
|
||||
user_profiles.created_at AS created_at
|
||||
FROM
|
||||
user_profiles
|
||||
LEFT JOIN profiles_twitch USING (profileid)
|
||||
LEFT JOIN user_preferences USING (profileid);
|
||||
"""
|
||||
_tablename_ = 'dreamers'
|
||||
_readonly_ = True
|
||||
|
||||
user_id = Integer(primary=True)
|
||||
name = String()
|
||||
twitch_id = Integer()
|
||||
preferences = String()
|
||||
created_at = Timestamp()
|
||||
|
||||
class Transaction(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE user_wallet (
|
||||
transaction_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE,
|
||||
amount INTEGER NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
reference TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'user_wallet'
|
||||
_cache_ = {}
|
||||
_immutable_ = True
|
||||
|
||||
transaction_id = Integer(primary=True)
|
||||
user_id = Integer()
|
||||
amount = Integer()
|
||||
description = String()
|
||||
reference = String()
|
||||
created_at = Timestamp()
|
||||
|
||||
|
||||
class StampType(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE stamp_types (
|
||||
stamp_type_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
stamp_type_name TEXT UNIQUE NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'stamp_types'
|
||||
_cache_ = {}
|
||||
|
||||
stamp_type_id = Integer(primary=True)
|
||||
stamp_type_name = String()
|
||||
created_at = Timestamp()
|
||||
|
||||
class Document(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE documents (
|
||||
document_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
document_data TEXT NOT NULL,
|
||||
seal INTEGER NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'documents'
|
||||
_cache_ = {}
|
||||
|
||||
document_id = Integer(primary=True)
|
||||
document_data = Column()
|
||||
seal = Integer()
|
||||
metadata = String()
|
||||
created_at = Timestamp()
|
||||
|
||||
def to_bytes(self):
|
||||
"""
|
||||
Helper method to decode the saved document data to a byte string.
|
||||
This may fail if the saved string is not base64 encoded.
|
||||
"""
|
||||
byts = BytesIO(base64.b64decode(self.document_data))
|
||||
return byts
|
||||
|
||||
class DocumentStamp(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE document_stamps (
|
||||
stamp_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
document_id INTEGER NOT NULL REFERENCES documents (document_id) ON DELETE CASCADE,
|
||||
stamp_type INTEGER NOT NULL REFERENCES stamp_types (stamp_type_id) ON DELETE CASCADE,
|
||||
position_x INTEGER NOT NULL,
|
||||
position_y INTEGER NOT NULL,
|
||||
rotation REAL NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'document_stamps'
|
||||
_cache_ = {}
|
||||
|
||||
stamp_id = Integer(primary=True)
|
||||
document_id = Integer()
|
||||
stamp_type = Integer()
|
||||
position_x = Integer()
|
||||
position_y = Integer()
|
||||
rotation: Column[float] = Column()
|
||||
created_at = Timestamp()
|
||||
|
||||
class Events(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE events (
|
||||
event_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE,
|
||||
document_id INTEGER REFERENCES documents (document_id) ON DELETE SET NULL,
|
||||
user_name TEXT,
|
||||
event_type EventType NOT NULL,
|
||||
occurred_at TIMESTAMPTZ NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (event_id, event_type)
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'events'
|
||||
_cache_ = {}
|
||||
|
||||
event_id = Integer(primary=True)
|
||||
user_id = Integer()
|
||||
document_id = Integer()
|
||||
user_name = String()
|
||||
event_type: Column[EventType] = Column()
|
||||
occured_at = Timestamp()
|
||||
created_at = Timestamp()
|
||||
|
||||
plain_events = Table('plain_events')
|
||||
raid_events = Table('raid_events')
|
||||
cheer_events = Table('cheer_events')
|
||||
subscriber_events = Table('subscriber_events')
|
||||
|
||||
class EventDetails(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE plain_events (
|
||||
event_id integer PRIMARY KEY,
|
||||
event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'),
|
||||
message TEXT NOT NULL,
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE raid_events (
|
||||
event_id integer PRIMARY KEY,
|
||||
event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'),
|
||||
visitor_count INTEGER NOT NULL,
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE cheer_events (
|
||||
event_id integer PRIMARY KEY,
|
||||
event_type EventType NOT NULL DEFAULT 'cheer' CHECK (event_type = 'cheer'),
|
||||
amount INTEGER NOT NULL,
|
||||
cheer_type TEXT,
|
||||
message TEXT,
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE subscriber_events (
|
||||
event_id integer PRIMARY KEY,
|
||||
event_type EventType NOT NULL DEFAULT 'subscriber' CHECK (event_type = 'subscriber'),
|
||||
subscribed_length INTEGER NOT NULL,
|
||||
tier INTEGER NOT NULL,
|
||||
message TEXT,
|
||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE VIEW event_details AS
|
||||
SELECT
|
||||
events.event_id AS event_id,
|
||||
events.user_id AS user_id,
|
||||
events.document_id AS document_id,
|
||||
events.user_name AS user_name,
|
||||
events.event_type AS event_type,
|
||||
events.occurred_at AS occurred_at,
|
||||
events.created_at AS created_at,
|
||||
plain_events.message AS plain_message,
|
||||
raid_events.visitor_count AS raid_visitor_count,
|
||||
cheer_events.amount AS cheer_amount,
|
||||
cheer_events.cheer_type AS cheer_type,
|
||||
cheer_events.message AS cheer_message,
|
||||
subscriber_events.subscribed_length AS subscriber_length,
|
||||
subscriber_events.tier AS subscriber_tier,
|
||||
subscriber_events.message AS subscriber_message,
|
||||
documents.seal AS document_seal
|
||||
FROM
|
||||
events
|
||||
LEFT JOIN plain_events USING (event_id)
|
||||
LEFT JOIN raid_events USING (event_id)
|
||||
LEFT JOIN cheer_events USING (event_id)
|
||||
LEFT JOIN subscriber_events USING (event_id)
|
||||
LEFT JOIN documents USING (document_id)
|
||||
ORDER BY events.occurred_at ASC;
|
||||
"""
|
||||
_tablename_ = 'event_details'
|
||||
_readonly_ = True
|
||||
|
||||
event_id = Integer(primary=True)
|
||||
user_id = Integer()
|
||||
document_id = Integer()
|
||||
user_name = String()
|
||||
event_type: Column[EventType] = Column()
|
||||
occurred_at = Timestamp()
|
||||
created_at = Timestamp()
|
||||
plain_message = String()
|
||||
raid_visitor_count = Integer()
|
||||
cheer_amount = Integer()
|
||||
cheer_type = String()
|
||||
cheer_message = String()
|
||||
subscriber_length = Integer()
|
||||
subscriber_tier = Integer()
|
||||
subscriber_message = String()
|
||||
document_seal = Integer()
|
||||
|
||||
|
||||
class Specimen(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE user_specimens (
|
||||
specimen_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
owner_id INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE,
|
||||
born_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
forgotten_at TIMESTAMPTZ
|
||||
);
|
||||
CREATE UNIQUE INDEX ON user_specimens (owner_id) WHERE forgotten_at IS NULL;
|
||||
"""
|
||||
_tablename_ = 'user_specimens'
|
||||
_cache_ = {}
|
||||
|
||||
specimen_id = Integer(primary=True)
|
||||
owner_id = Integer(primary=True)
|
||||
born_at = Timestamp()
|
||||
forgotten_at = Timestamp()
|
||||
76
src/meta/CrocBot.py
Normal file
76
src/meta/CrocBot.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import logging
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
from twitchio.ext import pubsub
|
||||
from twitchio.ext.commands.core import itertools
|
||||
|
||||
from data import Database
|
||||
|
||||
from .config import Conf
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CrocBot(commands.Bot):
|
||||
def __init__(self, *args,
|
||||
config: Conf,
|
||||
data: Database,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.config = config
|
||||
self.data = data
|
||||
self.pubsub = pubsub.PubSubPool(self)
|
||||
|
||||
self._member_cache = defaultdict(dict)
|
||||
|
||||
async def event_ready(self):
|
||||
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
||||
|
||||
async def event_join(self, channel: twitchio.Channel, user: twitchio.User):
|
||||
self._member_cache[channel.name][user.name] = user
|
||||
|
||||
async def event_message(self, message: twitchio.Message):
|
||||
if message.channel and message.author:
|
||||
self._member_cache[message.channel.name][message.author.name] = message.author
|
||||
await self.handle_commands(message)
|
||||
|
||||
async def seek_user(self, userstr: str, matching=True, fuzzy=True):
|
||||
if userstr.startswith('@'):
|
||||
matching = False
|
||||
userstr = userstr.strip('@ ')
|
||||
|
||||
result = None
|
||||
if matching and len(userstr) >= 3:
|
||||
lowered = userstr.lower()
|
||||
full_matches = []
|
||||
for user in itertools.chain(*(cmems.values() for cmems in self._member_cache.values())):
|
||||
matchstr = user.name.lower()
|
||||
print(matchstr)
|
||||
if matchstr.startswith(lowered):
|
||||
result = user
|
||||
break
|
||||
if lowered in matchstr:
|
||||
full_matches.append(user)
|
||||
if result is None and full_matches:
|
||||
result = full_matches[0]
|
||||
print(result)
|
||||
|
||||
if result is None:
|
||||
lookup = userstr
|
||||
elif result.id is None:
|
||||
lookup = result.name
|
||||
else:
|
||||
lookup = None
|
||||
|
||||
if lookup:
|
||||
found = await self.fetch_users(names=[lookup])
|
||||
if found:
|
||||
result = found[0]
|
||||
|
||||
# No matches found
|
||||
return result
|
||||
@@ -9,6 +9,7 @@ from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
|
||||
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
|
||||
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
|
||||
from aiohttp import ClientSession
|
||||
from twitchAPI.twitch import Twitch
|
||||
|
||||
from data import Database
|
||||
from utils.lib import tabulate
|
||||
@@ -23,6 +24,8 @@ from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.cog import CoreCog
|
||||
from twitch.cog import TwitchAuthCog
|
||||
from modules.profiles.cog import ProfileCog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,7 +34,9 @@ class LionBot(Bot):
|
||||
def __init__(
|
||||
self, *args, appname: str, shardname: str, db: Database, config: Conf,
|
||||
initial_extensions: List[str], web_client: ClientSession,
|
||||
testing_guilds: List[int] = [], **kwargs
|
||||
twitch: Twitch,
|
||||
testing_guilds: List[int] = [],
|
||||
**kwargs
|
||||
):
|
||||
kwargs.setdefault('tree_cls', LionTree)
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -43,6 +48,7 @@ class LionBot(Bot):
|
||||
self.shardname = shardname
|
||||
# self.appdata = appdata
|
||||
self.config = config
|
||||
self.twitch = twitch
|
||||
|
||||
self.system_monitor = SystemMonitor()
|
||||
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
|
||||
@@ -101,6 +107,14 @@ class LionBot(Bot):
|
||||
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: Literal['ProfileCog']) -> 'ProfileCog':
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: Literal['TwitchAuthCog']) -> 'TwitchAuthCog':
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: str) -> Optional[Cog]:
|
||||
...
|
||||
@@ -189,7 +203,7 @@ class LionBot(Bot):
|
||||
# TODO: Some of these could have more user-feedback
|
||||
logger.debug(f"Handling command error for {ctx}: {exception}")
|
||||
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
|
||||
cmd_str = ctx.command.app_command.to_dict()
|
||||
cmd_str = ctx.command.app_command.to_dict(self.tree)
|
||||
else:
|
||||
cmd_str = str(ctx.command)
|
||||
try:
|
||||
|
||||
@@ -133,7 +133,7 @@ class LionTree(CommandTree):
|
||||
return
|
||||
|
||||
set_logging_context(action=f"Run {command.qualified_name}")
|
||||
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
|
||||
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict(self)}")
|
||||
try:
|
||||
await command._invoke_with_namespace(interaction, namespace)
|
||||
except AppCommandError as e:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
this_package = 'modules'
|
||||
|
||||
active = [
|
||||
'.profiles',
|
||||
'.sysadmin',
|
||||
'.dreamspace',
|
||||
]
|
||||
|
||||
|
||||
|
||||
8
src/modules/dreamspace/__init__.py
Normal file
8
src/modules/dreamspace/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
from .cog import DreamCog
|
||||
await bot.add_cog(DreamCog(bot))
|
||||
70
src/modules/dreamspace/cog.py
Normal file
70
src/modules/dreamspace/cog.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from discord import app_commands as appcmds
|
||||
from discord.ext import commands as cmds
|
||||
|
||||
from meta import LionCog, LionBot, LionContext
|
||||
from meta.logger import log_wrap
|
||||
from utils.lib import utc_now
|
||||
|
||||
from . import logger
|
||||
from .ui.docviewer import DocumentViewer
|
||||
|
||||
|
||||
class DreamCog(LionCog):
|
||||
"""
|
||||
Discord-facting interface for Dreamspace Adventures
|
||||
"""
|
||||
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = bot.core.datamodel
|
||||
|
||||
async def cog_load(self):
|
||||
pass
|
||||
|
||||
@log_wrap(action="Dreamer migration")
|
||||
async def migrate_dreamer(self, source_profile, target_profile):
|
||||
"""
|
||||
Called when two dreamer profiles need to merge.
|
||||
|
||||
For example, when a user links a second twitch profile.
|
||||
|
||||
:TODO-MARKER:
|
||||
Most of the migration logic is simple, e.g. just update the profileid
|
||||
on the old events to the new profile.
|
||||
The same applies to transactions and probably to inventory items.
|
||||
However, there are some subtle choices to make, such as what to do
|
||||
if both the old and the new profile have an active specimen?
|
||||
A profile can only have one active specimen at a time.
|
||||
There is also the question of how to merge user preferences, when those exist.
|
||||
"""
|
||||
...
|
||||
|
||||
# User command: view their dreamer card, wallet inventory etc
|
||||
|
||||
# (Admin): View events/documents matching certain criteria
|
||||
|
||||
# (User): View own event cards with info?
|
||||
# Let's make a demo viewer which lists their event cards and let's them open one via select?
|
||||
# /documents -> Show a paged list of documents, select option displays the document in a viewer
|
||||
@cmds.hybrid_command(
|
||||
name='documents',
|
||||
description="View your printer log!"
|
||||
)
|
||||
async def documents_cmd(self, ctx: LionContext):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||
events = await self.data.Events.fetch_where(user_id=profile.profileid)
|
||||
docids = [event.document_id for event in events if event.document_id is not None]
|
||||
if not docids:
|
||||
await ctx.error_reply("You don't have any documents yet!")
|
||||
return
|
||||
|
||||
view = DocumentViewer(self.bot, ctx.interaction.user.id, filter=(self.data.Document.document_id == docids))
|
||||
await view.run(ctx.interaction)
|
||||
|
||||
# (User): View Specimen information
|
||||
|
||||
|
||||
|
||||
36
src/modules/dreamspace/events.py
Normal file
36
src/modules/dreamspace/events.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from datamodels import DataModel
|
||||
|
||||
|
||||
class Event:
|
||||
_typs = {}
|
||||
|
||||
def __init__(self, event_row: DataModel.Events, **kwargs):
|
||||
self.row = event_row
|
||||
|
||||
def __getattribute__(self, name: str):
|
||||
...
|
||||
|
||||
async def get_document(self):
|
||||
...
|
||||
|
||||
async def get_user(self):
|
||||
...
|
||||
|
||||
|
||||
class Document:
|
||||
def as_bytes(self):
|
||||
...
|
||||
|
||||
async def get_stamps(self):
|
||||
...
|
||||
|
||||
async def refresh(self):
|
||||
...
|
||||
|
||||
|
||||
class User:
|
||||
...
|
||||
|
||||
|
||||
class Stamp:
|
||||
...
|
||||
171
src/modules/dreamspace/ui/docviewer.py
Normal file
171
src/modules/dreamspace/ui/docviewer.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import binascii
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
|
||||
import discord
|
||||
from discord.ui.select import select, Select, UserSelect
|
||||
from discord.ui.button import button, Button
|
||||
from discord.ui.text_input import TextInput
|
||||
from discord.enums import ButtonStyle, TextStyle
|
||||
from discord.components import SelectOption
|
||||
|
||||
from datamodels import DataModel
|
||||
from meta import LionBot, conf
|
||||
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
|
||||
from data import ORDER, Condition
|
||||
|
||||
from utils.ui import MessageUI, input
|
||||
from utils.lib import MessageArgs, tabulate, utc_now
|
||||
|
||||
from .. import logger
|
||||
|
||||
class DocumentViewer(MessageUI):
|
||||
"""
|
||||
Simple pager which displays a filtered list of Documents.
|
||||
"""
|
||||
block_len = 5
|
||||
|
||||
def __init__(self, bot: LionBot, callerid: int, filter: Condition, **kwargs):
|
||||
super().__init__(callerid=callerid, **kwargs)
|
||||
|
||||
self.bot = bot
|
||||
self.data: DataModel = bot.core.datamodel
|
||||
self.filter = filter
|
||||
|
||||
# Paging state
|
||||
self._pagen = 0
|
||||
self.blocks = [[]]
|
||||
|
||||
@property
|
||||
def page_count(self):
|
||||
return len(self.blocks)
|
||||
|
||||
@property
|
||||
def pagen(self):
|
||||
self._pagen %= self.page_count
|
||||
return self._pagen
|
||||
|
||||
@pagen.setter
|
||||
def pagen(self, value):
|
||||
self._pagen = value % self.page_count
|
||||
|
||||
@property
|
||||
def current_page(self):
|
||||
return self.blocks[self.pagen]
|
||||
|
||||
# ----- UI Components -----
|
||||
|
||||
# Page backwards
|
||||
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
|
||||
async def prev_button(self, press: discord.Interaction, pressed: Button):
|
||||
await press.response.defer(thinking=True)
|
||||
self.pagen -= 1
|
||||
await self.refresh(thinking=press)
|
||||
|
||||
# Jump to page
|
||||
@button(label="JUMP_PLACEHOLDER", style=ButtonStyle.blurple)
|
||||
async def jump_button(self, press: discord.Interaction, pressed: Button):
|
||||
"""
|
||||
Jump to page button.
|
||||
"""
|
||||
try:
|
||||
interaction, value = await input(
|
||||
press,
|
||||
title="Jump to page",
|
||||
question="Page number to jump to"
|
||||
)
|
||||
value = value.strip()
|
||||
except asyncio.TimeoutError:
|
||||
return
|
||||
|
||||
if not value.lstrip('- ').isdigit():
|
||||
error = discord.Embed(title="Invalid page number, please try again!",
|
||||
colour=discord.Colour.brand_red())
|
||||
await interaction.response.send_message(embed=error, ephemeral=True)
|
||||
else:
|
||||
await interaction.response.defer(thinking=True)
|
||||
pagen = int(value.lstrip('- '))
|
||||
if value.startswith('-'):
|
||||
pagen = -1 * pagen
|
||||
elif pagen > 0:
|
||||
pagen = pagen - 1
|
||||
self.pagen = pagen
|
||||
await self.refresh(thinking=interaction)
|
||||
|
||||
async def jump_button_refresh(self):
|
||||
component = self.jump_button
|
||||
component.label = f"{self.pagen + 1}/{self.page_count}"
|
||||
component.disabled = (self.page_count <= 1)
|
||||
|
||||
# Page forwards
|
||||
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
|
||||
async def next_button(self, press: discord.Interaction, pressed: Button):
|
||||
await press.response.defer(thinking=True)
|
||||
self.pagen += 1
|
||||
await self.refresh(thinking=press)
|
||||
|
||||
# Quit
|
||||
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
|
||||
async def quit_button(self, press: discord.Interaction, pressed: Button):
|
||||
"""
|
||||
Quit the UI.
|
||||
"""
|
||||
await press.response.defer()
|
||||
# if self.child_viewer:
|
||||
# await self.child_viewer.quit()
|
||||
await self.quit()
|
||||
|
||||
# ----- UI Flow -----
|
||||
|
||||
async def make_message(self) -> MessageArgs:
|
||||
files = []
|
||||
embeds = []
|
||||
for doc in self.current_page:
|
||||
try:
|
||||
imagedata = doc.to_bytes()
|
||||
imagedata.seek(0)
|
||||
except binascii.Error:
|
||||
continue
|
||||
fn = f"doc-{doc.document_id}.png"
|
||||
file = discord.File(imagedata, fn)
|
||||
embed = discord.Embed()
|
||||
embed.set_image(url=f"attachment://{fn}")
|
||||
files.append(file)
|
||||
embeds.append(embed)
|
||||
|
||||
if not embeds:
|
||||
embed = discord.Embed(description="You don't have any documents yet!")
|
||||
embeds.append(embed)
|
||||
|
||||
print(f"FILES: {files}")
|
||||
|
||||
return MessageArgs(files=files, embeds=embeds)
|
||||
|
||||
async def refresh_layout(self):
|
||||
to_refresh = (
|
||||
self.jump_button_refresh(),
|
||||
)
|
||||
await asyncio.gather(*to_refresh)
|
||||
if self.page_count > 1:
|
||||
page_line = (
|
||||
self.prev_button,
|
||||
self.jump_button,
|
||||
self.quit_button,
|
||||
self.next_button,
|
||||
)
|
||||
else:
|
||||
page_line = (self.quit_button,)
|
||||
|
||||
self.set_layout(page_line)
|
||||
|
||||
async def reload(self):
|
||||
docs = await self.data.Document.fetch_where(self.filter).order_by('created_at', ORDER.DESC)
|
||||
blocks = [
|
||||
docs[i:i+self.block_len]
|
||||
for i in range(0, len(docs), self.block_len)
|
||||
]
|
||||
self.blocks = blocks or [[]]
|
||||
|
||||
406
src/modules/dreamspace/ui/eventviewer.py
Normal file
406
src/modules/dreamspace/ui/eventviewer.py
Normal file
@@ -0,0 +1,406 @@
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
|
||||
import discord
|
||||
from discord.ui.select import select, Select, UserSelect
|
||||
from discord.ui.button import button, Button
|
||||
from discord.ui.text_input import TextInput
|
||||
from discord.enums import ButtonStyle, TextStyle
|
||||
from discord.components import SelectOption
|
||||
|
||||
from datamodels import DataModel
|
||||
from meta import LionBot, conf
|
||||
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
|
||||
from data import ORDER, Condition
|
||||
|
||||
from utils.ui import MessageUI, input
|
||||
from utils.lib import MessageArgs, tabulate, utc_now
|
||||
|
||||
from .. import logger
|
||||
|
||||
|
||||
class EventsUI(MessageUI):
|
||||
block_len = 10
|
||||
|
||||
def __init__(self, bot: LionBot, callerid: int, filter: Condition, **kwargs):
|
||||
super().__init__(callerid=callerid, **kwargs)
|
||||
|
||||
self.bot = bot
|
||||
self.data: DataModel = bot.core.datamodel
|
||||
self.filter = Condition
|
||||
|
||||
# Paging state
|
||||
self._pagen = 0
|
||||
self.blocks = [[]]
|
||||
|
||||
@property
|
||||
def page_count(self):
|
||||
return len(self.blocks)
|
||||
|
||||
@property
|
||||
def pagen(self):
|
||||
self._pagen %= self.page_count
|
||||
return self._pagen
|
||||
|
||||
@pagen.setter
|
||||
def pagen(self, value):
|
||||
self._pagen = value % self.page_count
|
||||
|
||||
@property
|
||||
def current_page(self):
|
||||
return self.blocks[self.pagen]
|
||||
|
||||
# ----- UI Components -----
|
||||
|
||||
# Page backwards
|
||||
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
|
||||
async def prev_button(self, press: discord.Interaction, pressed: Button):
|
||||
await press.response.defer(thinking=True)
|
||||
self.pagen -= 1
|
||||
await self.refresh(thinking=press)
|
||||
|
||||
# Jump to page
|
||||
|
||||
# Page forwards
|
||||
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
|
||||
async def next_button(self, press: discord.Interaction, pressed: Button):
|
||||
await press.response.defer(thinking=True)
|
||||
self.pagen += 1
|
||||
await self.refresh(thinking=press)
|
||||
|
||||
# Quit
|
||||
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
|
||||
async def quit_button(self, press: discord.Interaction, pressed: Button):
|
||||
"""
|
||||
Quit the UI.
|
||||
"""
|
||||
await press.response.defer()
|
||||
# if self.child_viewer:
|
||||
# await self.child_viewer.quit()
|
||||
await self.quit()
|
||||
|
||||
# ----- UI Flow -----
|
||||
|
||||
async def make_message(self) -> MessageArgs:
|
||||
...
|
||||
|
||||
async def refresh_layout(self):
|
||||
...
|
||||
|
||||
async def reload(self):
|
||||
...
|
||||
|
||||
|
||||
|
||||
class TicketListUI(MessageUI):
|
||||
# Select Ticket
|
||||
@select(
|
||||
cls=Select,
|
||||
placeholder="TICKETS_MENU_PLACEHOLDER",
|
||||
min_values=1, max_values=1
|
||||
)
|
||||
async def tickets_menu(self, selection: discord.Interaction, selected: Select):
|
||||
await selection.response.defer(thinking=True, ephemeral=True)
|
||||
if selected.values:
|
||||
ticketid = int(selected.values[0])
|
||||
ticket = await Ticket.fetch_ticket(self.bot, ticketid)
|
||||
ticketui = TicketUI(self.bot, ticket, self._callerid)
|
||||
if self.child_ticket:
|
||||
await self.child_ticket.quit()
|
||||
self.child_ticket = ticketui
|
||||
await ticketui.run(selection)
|
||||
|
||||
async def tickets_menu_refresh(self):
|
||||
menu = self.tickets_menu
|
||||
t = self.bot.translator.t
|
||||
menu.placeholder = t(_p(
|
||||
'ui:tickets|menu:tickets|placeholder',
|
||||
"Select Ticket"
|
||||
))
|
||||
options = []
|
||||
for ticket in self.current_page:
|
||||
option = SelectOption(
|
||||
label=f"Ticket #{ticket.data.guild_ticketid}",
|
||||
value=str(ticket.data.ticketid)
|
||||
)
|
||||
options.append(option)
|
||||
menu.options = options
|
||||
|
||||
# Backwards
|
||||
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
|
||||
async def prev_button(self, press: discord.Interaction, pressed: Button):
|
||||
await press.response.defer(thinking=True, ephemeral=True)
|
||||
self.pagen -= 1
|
||||
await self.refresh(thinking=press)
|
||||
|
||||
# Jump to page
|
||||
@button(label="JUMP_PLACEHOLDER", style=ButtonStyle.blurple)
|
||||
async def jump_button(self, press: discord.Interaction, pressed: Button):
|
||||
"""
|
||||
Jump-to-page button.
|
||||
Loads a page-switch dialogue.
|
||||
"""
|
||||
t = self.bot.translator.t
|
||||
try:
|
||||
interaction, value = await input(
|
||||
press,
|
||||
title=t(_p(
|
||||
'ui:tickets|button:jump|input:title',
|
||||
"Jump to page"
|
||||
)),
|
||||
question=t(_p(
|
||||
'ui:tickets|button:jump|input:question',
|
||||
"Page number to jump to"
|
||||
))
|
||||
)
|
||||
value = value.strip()
|
||||
except asyncio.TimeoutError:
|
||||
return
|
||||
|
||||
if not value.lstrip('- ').isdigit():
|
||||
error_embed = discord.Embed(
|
||||
title=t(_p(
|
||||
'ui:tickets|button:jump|error:invalid_page',
|
||||
"Invalid page number, please try again!"
|
||||
)),
|
||||
colour=discord.Colour.brand_red()
|
||||
)
|
||||
await interaction.response.send_message(embed=error_embed, ephemeral=True)
|
||||
else:
|
||||
await interaction.response.defer(thinking=True)
|
||||
pagen = int(value.lstrip('- '))
|
||||
if value.startswith('-'):
|
||||
pagen = -1 * pagen
|
||||
elif pagen > 0:
|
||||
pagen = pagen - 1
|
||||
self.pagen = pagen
|
||||
await self.refresh(thinking=interaction)
|
||||
|
||||
async def jump_button_refresh(self):
|
||||
component = self.jump_button
|
||||
component.label = f"{self.pagen + 1}/{self.page_count}"
|
||||
component.disabled = (self.page_count <= 1)
|
||||
|
||||
# Forward
|
||||
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
|
||||
async def next_button(self, press: discord.Interaction, pressed: Button):
|
||||
await press.response.defer(thinking=True)
|
||||
self.pagen += 1
|
||||
await self.refresh(thinking=press)
|
||||
|
||||
# Quit
|
||||
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
|
||||
async def quit_button(self, press: discord.Interaction, pressed: Button):
|
||||
"""
|
||||
Quit the UI.
|
||||
"""
|
||||
await press.response.defer()
|
||||
if self.child_ticket:
|
||||
await self.child_ticket.quit()
|
||||
await self.quit()
|
||||
|
||||
# ----- UI Flow -----
|
||||
def _format_ticket(self, ticket) -> str:
|
||||
"""
|
||||
Format a ticket into a single embed line.
|
||||
"""
|
||||
components = (
|
||||
"[#{ticketid}]({link})",
|
||||
"{created}",
|
||||
"`{type}[{state}]`",
|
||||
"<@{targetid}>",
|
||||
"{content}",
|
||||
)
|
||||
|
||||
formatstr = ' | '.join(components)
|
||||
|
||||
data = ticket.data
|
||||
if not data.content:
|
||||
content = 'No Content'
|
||||
elif len(data.content) > 100:
|
||||
content = data.content[:97] + '...'
|
||||
else:
|
||||
content = data.content
|
||||
|
||||
ticketstr = formatstr.format(
|
||||
ticketid=data.guild_ticketid,
|
||||
link=ticket.jump_url or 'https://lionbot.org',
|
||||
created=discord.utils.format_dt(data.created_at, 'd'),
|
||||
type=data.ticket_type.name,
|
||||
state=data.ticket_state.name,
|
||||
targetid=data.targetid,
|
||||
content=content,
|
||||
)
|
||||
if data.ticket_state is TicketState.PARDONED:
|
||||
ticketstr = f"~~{ticketstr}~~"
|
||||
return ticketstr
|
||||
|
||||
async def make_message(self) -> MessageArgs:
|
||||
t = self.bot.translator.t
|
||||
embed = discord.Embed(
|
||||
title=t(_p(
|
||||
'ui:tickets|embed|title',
|
||||
"Moderation Tickets in {guild}"
|
||||
)).format(guild=self.guild.name),
|
||||
timestamp=utc_now()
|
||||
)
|
||||
tickets = self.current_page
|
||||
if tickets:
|
||||
desc = '\n'.join(self._format_ticket(ticket) for ticket in tickets)
|
||||
else:
|
||||
desc = t(_p(
|
||||
'ui:tickets|embed|desc:no_tickets',
|
||||
"No tickets matching the given criteria!"
|
||||
))
|
||||
embed.description = desc
|
||||
|
||||
filterstr = self.filters.formatted()
|
||||
if filterstr:
|
||||
embed.add_field(
|
||||
name=t(_p(
|
||||
'ui:tickets|embed|field:filters|name',
|
||||
"Filters"
|
||||
)),
|
||||
value=filterstr,
|
||||
inline=False
|
||||
)
|
||||
|
||||
return MessageArgs(embed=embed)
|
||||
|
||||
async def refresh_layout(self):
|
||||
to_refresh = (
|
||||
self.edit_filter_button_refresh(),
|
||||
self.select_ticket_button_refresh(),
|
||||
self.pardon_button_refresh(),
|
||||
self.tickets_menu_refresh(),
|
||||
self.filter_type_menu_refresh(),
|
||||
self.filter_state_menu_refresh(),
|
||||
self.filter_target_menu_refresh(),
|
||||
self.jump_button_refresh(),
|
||||
)
|
||||
await asyncio.gather(*to_refresh)
|
||||
|
||||
action_line = (
|
||||
self.edit_filter_button,
|
||||
self.select_ticket_button,
|
||||
self.pardon_button,
|
||||
)
|
||||
|
||||
if self.page_count > 1:
|
||||
page_line = (
|
||||
self.prev_button,
|
||||
self.jump_button,
|
||||
self.quit_button,
|
||||
self.next_button,
|
||||
)
|
||||
else:
|
||||
page_line = ()
|
||||
action_line = (*action_line, self.quit_button)
|
||||
|
||||
if self.show_filters:
|
||||
menus = (
|
||||
(self.filter_type_menu,),
|
||||
(self.filter_state_menu,),
|
||||
(self.filter_target_menu,),
|
||||
)
|
||||
elif self.show_tickets and self.current_page:
|
||||
menus = ((self.tickets_menu,),)
|
||||
else:
|
||||
menus = ()
|
||||
|
||||
self.set_layout(
|
||||
action_line,
|
||||
*menus,
|
||||
page_line,
|
||||
)
|
||||
|
||||
async def reload(self):
|
||||
tickets = await Ticket.fetch_tickets(
|
||||
self.bot,
|
||||
*self.filters.conditions(),
|
||||
guildid=self.guild.id,
|
||||
)
|
||||
blocks = [
|
||||
tickets[i:i+self.block_len]
|
||||
for i in range(0, len(tickets), self.block_len)
|
||||
]
|
||||
self.blocks = blocks or [[]]
|
||||
|
||||
|
||||
class TicketUI(MessageUI):
|
||||
def __init__(self, bot: LionBot, ticket: Ticket, callerid: int, **kwargs):
|
||||
super().__init__(callerid=callerid, **kwargs)
|
||||
|
||||
self.bot = bot
|
||||
self.ticket = ticket
|
||||
|
||||
# ----- API -----
|
||||
|
||||
# ----- UI Components -----
|
||||
# Pardon Ticket
|
||||
@button(
|
||||
label="PARDON_BUTTON_PLACEHOLDER",
|
||||
style=ButtonStyle.red
|
||||
)
|
||||
async def pardon_button(self, press: discord.Interaction, pressed: Button):
|
||||
t = self.bot.translator.t
|
||||
|
||||
modal_title = t(_p(
|
||||
'ui:ticket|button:pardon|modal:reason|title',
|
||||
"Pardon Moderation Ticket"
|
||||
))
|
||||
input_field = TextInput(
|
||||
label=t(_p(
|
||||
'ui:ticket|button:pardon|modal:reason|field|label',
|
||||
"Why are you pardoning this ticket?"
|
||||
)),
|
||||
style=TextStyle.long,
|
||||
min_length=0,
|
||||
max_length=1024,
|
||||
)
|
||||
try:
|
||||
interaction, reason = await input(
|
||||
press, modal_title, field=input_field, timeout=300,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise ResponseTimedOut
|
||||
|
||||
await interaction.response.defer(thinking=True, ephemeral=True)
|
||||
|
||||
await self.ticket.pardon(modid=press.user.id, reason=reason)
|
||||
await self.refresh(thinking=interaction)
|
||||
|
||||
|
||||
async def pardon_button_refresh(self):
|
||||
button = self.pardon_button
|
||||
t = self.bot.translator.t
|
||||
button.label = t(_p(
|
||||
'ui:ticket|button:pardon|label',
|
||||
"Pardon"
|
||||
))
|
||||
button.disabled = (self.ticket.data.ticket_state is TicketState.PARDONED)
|
||||
|
||||
# Quit
|
||||
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
|
||||
async def quit_button(self, press: discord.Interaction, pressed: Button):
|
||||
"""
|
||||
Quit the UI.
|
||||
"""
|
||||
await press.response.defer()
|
||||
await self.quit()
|
||||
|
||||
# ----- UI Flow -----
|
||||
async def make_message(self) -> MessageArgs:
|
||||
return await self.ticket.make_message()
|
||||
|
||||
async def refresh_layout(self):
|
||||
await self.pardon_button_refresh()
|
||||
self.set_layout(
|
||||
(self.pardon_button, self.quit_button,)
|
||||
)
|
||||
|
||||
async def reload(self):
|
||||
await self.ticket.data.refresh()
|
||||
0
src/modules/dreamspace/ui/specimen.py
Normal file
0
src/modules/dreamspace/ui/specimen.py
Normal file
8
src/modules/profiles/__init__.py
Normal file
8
src/modules/profiles/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import ProfileCog
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(ProfileCog(bot))
|
||||
455
src/modules/profiles/cog.py
Normal file
455
src/modules/profiles/cog.py
Normal file
@@ -0,0 +1,455 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Optional, overload
|
||||
from datetime import timedelta
|
||||
|
||||
import discord
|
||||
from discord import app_commands as appcmds
|
||||
from discord.ext import commands as cmds
|
||||
from twitchAPI.helper import first
|
||||
from twitchAPI.type import AuthScope
|
||||
import twitchio
|
||||
from twitchAPI.object.api import TwitchUser
|
||||
|
||||
|
||||
from data.queries import ORDER
|
||||
from meta import LionCog, LionBot, LionContext
|
||||
from meta.logger import log_wrap
|
||||
from utils.lib import utc_now
|
||||
from . import logger
|
||||
from .data import ProfileData
|
||||
from .profile import UserProfile
|
||||
from .community import Community
|
||||
|
||||
from .ui import TwitchLinkStatic, TwitchLinkFlow
|
||||
|
||||
|
||||
class ProfileCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
|
||||
self.data = bot.db.load_registry(ProfileData())
|
||||
|
||||
self._profile_migrators = {}
|
||||
self._comm_migrators = {}
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
self.bot.add_view(TwitchLinkStatic(timeout=None))
|
||||
|
||||
async def cog_check(self, ctx):
|
||||
return True
|
||||
|
||||
# Profile API
|
||||
def add_profile_migrator(self, migrator, name=None):
|
||||
name = name or migrator.__name__
|
||||
self._profile_migrators[name or migrator.__name__] = migrator
|
||||
|
||||
logger.info(
|
||||
f"Added user profile migrator {name}: {migrator}"
|
||||
)
|
||||
return migrator
|
||||
|
||||
def del_profile_migrator(self, name: str):
|
||||
migrator = self._profile_migrators.pop(name, None)
|
||||
|
||||
logger.info(
|
||||
f"Removed user profile migrator {name}: {migrator}"
|
||||
)
|
||||
|
||||
@log_wrap(action="profile migration")
|
||||
async def migrate_profile(self, source_profile, target_profile) -> list[str]:
|
||||
logger.info(
|
||||
f"Beginning user profile migration from {source_profile!r} to {target_profile!r}"
|
||||
)
|
||||
results = []
|
||||
# Wrap this in a transaction so if something goes wrong with migration,
|
||||
# we roll back safely (although this may mess up caches)
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
for name, migrator in self._profile_migrators.items():
|
||||
try:
|
||||
result = await migrator(source_profile, target_profile)
|
||||
if result:
|
||||
results.append(result)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unexpected exception running user profile migrator {name} "
|
||||
f"migrating {source_profile!r} to {target_profile!r}."
|
||||
)
|
||||
raise
|
||||
|
||||
# Move all Discord and Twitch profile references over to the new profile
|
||||
discord_rows = await self.data.DiscordProfileRow.table.update_where(
|
||||
profileid=source_profile.profileid
|
||||
).set(profileid=target_profile.profileid)
|
||||
results.append(f"Migrated {len(discord_rows)} attached discord account(s).")
|
||||
|
||||
twitch_rows = await self.data.TwitchProfileRow.table.update_where(
|
||||
profileid=source_profile.profileid
|
||||
).set(profileid=target_profile.profileid)
|
||||
results.append(f"Migrated {len(twitch_rows)} attached twitch account(s).")
|
||||
|
||||
# And then mark the old profile as migrated
|
||||
await source_profile.profile_row.update(migrated=target_profile.profileid)
|
||||
results.append("Marking old profile as migrated.. finished!")
|
||||
return results
|
||||
|
||||
async def fetch_profile_by_id(self, profile_id: int) -> UserProfile:
|
||||
"""
|
||||
Fetch a UserProfile by the given id.
|
||||
"""
|
||||
return await UserProfile.fetch(self.bot, profile_id=profile_id)
|
||||
|
||||
async def fetch_profile_discord(self, user: discord.Member | discord.User) -> UserProfile:
|
||||
"""
|
||||
Fetch or create a UserProfile from the provided discord account.
|
||||
"""
|
||||
profile = await UserProfile.fetch_from_discordid(self.bot, user.id)
|
||||
if profile is None:
|
||||
profile = await UserProfile.create_from_discord(self.bot, user)
|
||||
return profile
|
||||
|
||||
async def fetch_profile_twitch(self, user: twitchio.User) -> UserProfile:
|
||||
"""
|
||||
Fetch or create a UserProfile from the provided twitch account.
|
||||
"""
|
||||
profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
|
||||
if profile is None:
|
||||
profile = await UserProfile.create_from_twitch(self.bot, user)
|
||||
return profile
|
||||
|
||||
# Community API
|
||||
def add_community_migrator(self, migrator, name=None):
|
||||
name = name or migrator.__name__
|
||||
self._comm_migrators[name or migrator.__name__] = migrator
|
||||
|
||||
logger.info(
|
||||
f"Added community migrator {name}: {migrator}"
|
||||
)
|
||||
return migrator
|
||||
|
||||
def del_community_migrator(self, name: str):
|
||||
migrator = self._comm_migrators.pop(name, None)
|
||||
|
||||
logger.info(
|
||||
f"Removed community migrator {name}: {migrator}"
|
||||
)
|
||||
|
||||
@log_wrap(action="community migration")
|
||||
async def migrate_community(self, source_comm, target_comm) -> list[str]:
|
||||
logger.info(
|
||||
f"Beginning community migration from {source_comm!r} to {target_comm!r}"
|
||||
)
|
||||
results = []
|
||||
# Wrap this in a transaction so if something goes wrong with migration,
|
||||
# we roll back safely (although this may mess up caches)
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
for name, migrator in self._comm_migrators.items():
|
||||
try:
|
||||
result = await migrator(source_comm, target_comm)
|
||||
if result:
|
||||
results.append(result)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unexpected exception running community migrator {name} "
|
||||
f"migrating {source_comm!r} to {target_comm!r}."
|
||||
)
|
||||
raise
|
||||
|
||||
# Move all Discord and Twitch community preferences over to the new profile
|
||||
discord_rows = await self.data.DiscordCommunityRow.table.update_where(
|
||||
profileid=source_comm.communityid
|
||||
).set(communityid=target_comm.communityid)
|
||||
results.append(f"Migrated {len(discord_rows)} attached discord guilds.")
|
||||
|
||||
twitch_rows = await self.data.TwitchCommunityRow.table.update_where(
|
||||
communityid=source_comm.communityid
|
||||
).set(communityid=target_comm.communityid)
|
||||
results.append(f"Migrated {len(twitch_rows)} attached twitch channel(s).")
|
||||
|
||||
# And then mark the old community as migrated
|
||||
await source_comm.update(migrated=target_comm.communityid)
|
||||
results.append("Marking old community as migrated.. finished!")
|
||||
return results
|
||||
|
||||
async def fetch_community_by_id(self, community_id: int) -> Community:
|
||||
"""
|
||||
Fetch a Community by the given id.
|
||||
"""
|
||||
return await Community.fetch(self.bot, community_id=community_id)
|
||||
|
||||
async def fetch_community_discord(self, guild: discord.Guild) -> Community:
|
||||
"""
|
||||
Fetch or create a Community from the provided discord guild.
|
||||
"""
|
||||
comm = await Community.fetch_from_discordid(self.bot, guild.id)
|
||||
if comm is None:
|
||||
comm = await Community.create_from_discord(self.bot, guild)
|
||||
return comm
|
||||
|
||||
async def fetch_community_twitch(self, user: twitchio.User) -> Community:
|
||||
"""
|
||||
Fetch or create a Community from the provided twitch account.
|
||||
"""
|
||||
community = await Community.fetch_from_twitchid(self.bot, user.id)
|
||||
if community is None:
|
||||
community = await Community.create_from_twitch(self.bot, user)
|
||||
return community
|
||||
|
||||
# ----- Admin Commands -----
|
||||
@cmds.hybrid_command(
|
||||
name='linkoffer',
|
||||
description="Send a message with a permanent button for profile linking"
|
||||
)
|
||||
@appcmds.default_permissions(manage_guild=True)
|
||||
async def linkoffer_cmd(self, ctx: LionContext):
|
||||
view = TwitchLinkStatic(timeout=None)
|
||||
await ctx.channel.send(embed=view.embed, view=view)
|
||||
|
||||
# ----- Profile Commands -----
|
||||
@cmds.hybrid_group(
|
||||
name='profiles',
|
||||
description="Base comand group for user profiles."
|
||||
)
|
||||
async def profiles_grp(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@profiles_grp.group(
|
||||
name='link',
|
||||
description="Base command group for linking profiles"
|
||||
)
|
||||
async def profiles_link_grp(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@profiles_link_grp.command(
|
||||
name='twitch',
|
||||
description="Link a twitch account to your current profile."
|
||||
)
|
||||
async def profiles_link_twitch_cmd(self, ctx: LionContext):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
flowui = TwitchLinkFlow(self.bot, ctx.author, callerid=ctx.author.id)
|
||||
await flowui.run(ctx.interaction)
|
||||
await flowui.wait()
|
||||
|
||||
async def old_profiles_link_twitch_cmd(self, ctx: LionContext):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
|
||||
await ctx.interaction.response.defer(ephemeral=True)
|
||||
|
||||
# Ask the user to go through auth to get their userid
|
||||
auth_cog = self.bot.get_cog('TwitchAuthCog')
|
||||
flow = await auth_cog.start_auth()
|
||||
message = await ctx.reply(
|
||||
f"Please [click here]({flow.auth.return_auth_url()}) to link your profile "
|
||||
"to Twitch."
|
||||
)
|
||||
authrow = await flow.run()
|
||||
await message.edit(
|
||||
content="Authentication Complete! Beginning profile merge..."
|
||||
)
|
||||
|
||||
# results = await self.crocbot.fetch_users(ids=[authrow.userid])
|
||||
# if not results:
|
||||
# logger.error(
|
||||
# f"User {authrow} obtained from Twitch authentication does not exist."
|
||||
# )
|
||||
# await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||
# return
|
||||
|
||||
# user = results[0]
|
||||
try:
|
||||
user = await first(self.bot.twitch.get_users(user_ids=[str(authrow.userid)]))
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Looking up user {authrow} from Twitch authentication flow raised an error.",
|
||||
exc_info=True
|
||||
)
|
||||
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||
return
|
||||
|
||||
if user is None:
|
||||
logger.error(
|
||||
f"User {authrow} obtained from Twitch authentication does not exist."
|
||||
)
|
||||
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||
return
|
||||
|
||||
|
||||
# Retrieve author's profile if it exists
|
||||
author_profile = await UserProfile.fetch_from_discordid(self.bot, ctx.author.id)
|
||||
|
||||
# Check if the twitch-side user has a profile
|
||||
source_profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
|
||||
|
||||
if author_profile and source_profile is None:
|
||||
# All we need to do is attach the twitch row
|
||||
await author_profile.attach_twitch(user.id)
|
||||
await message.edit(
|
||||
content=f"Successfully added Twitch account **{user.display_name}**! There was no profile data to merge."
|
||||
)
|
||||
elif source_profile and author_profile is None:
|
||||
# Attach the discord row to the profile
|
||||
await source_profile.attach_discord(ctx.author.id)
|
||||
await message.edit(
|
||||
content=f"Successfully connected to Twitch profile **{user.display_name}**! There was no profile data to merge."
|
||||
)
|
||||
elif source_profile is None and author_profile is None:
|
||||
profile = await UserProfile.create_from_discord(self.bot, ctx.author)
|
||||
await profile.attach_twitch(user.id)
|
||||
|
||||
await message.edit(
|
||||
content=f"Opened a new user profile for you and linked Twitch account **{user.display_name}**."
|
||||
)
|
||||
elif author_profile.profileid == source_profile.profileid:
|
||||
await message.edit(
|
||||
content=f"The Twitch account **{user.display_name}** is already linked to your profile!"
|
||||
)
|
||||
else:
|
||||
# Migrate the existing profile data to the new profiles
|
||||
try:
|
||||
results = await self.migrate_profile(source_profile, author_profile)
|
||||
except Exception:
|
||||
await ctx.error_reply(
|
||||
"An issue was encountered while merging your account profiles!\n"
|
||||
"Migration rolled back, no data has been lost.\n"
|
||||
"The developer has been notified. Please try again later!"
|
||||
)
|
||||
raise
|
||||
|
||||
content = '\n'.join((
|
||||
"## Connecting Twitch account and merging profiles...",
|
||||
*results,
|
||||
"**Successfully linked account and merged profile data!**"
|
||||
))
|
||||
await message.edit(content=content)
|
||||
|
||||
# ----- Community Commands -----
|
||||
@cmds.hybrid_group(
|
||||
name='community',
|
||||
description="Base comand group for community profiles."
|
||||
)
|
||||
async def community_grp(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@community_grp.group(
|
||||
name='link',
|
||||
description="Base command group for linking communities"
|
||||
)
|
||||
async def community_link_grp(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@community_link_grp.command(
|
||||
name='twitch',
|
||||
description="Link a twitch account to this community."
|
||||
)
|
||||
@appcmds.guild_only()
|
||||
@appcmds.default_permissions(manage_guild=True)
|
||||
async def comm_link_twitch_cmd(self, ctx: LionContext):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
assert ctx.guild is not None
|
||||
|
||||
await ctx.interaction.response.defer(ephemeral=True)
|
||||
|
||||
if not ctx.author.guild_permissions.manage_guild:
|
||||
await ctx.error_reply("You need the `MANAGE_GUILD` permission to link this guild to a community.")
|
||||
return
|
||||
|
||||
# Ask the user to go through auth to get their userid
|
||||
auth_cog = self.bot.get_cog('TwitchAuthCog')
|
||||
flow = await auth_cog.start_auth(
|
||||
scopes=[
|
||||
AuthScope.CHAT_EDIT,
|
||||
AuthScope.CHAT_READ,
|
||||
AuthScope.MODERATION_READ,
|
||||
AuthScope.CHANNEL_BOT,
|
||||
]
|
||||
)
|
||||
message = await ctx.reply(
|
||||
f"Please [click here]({flow.auth.return_auth_url()}) to link your Twitch channel to this server."
|
||||
)
|
||||
authrow = await flow.run()
|
||||
await message.edit(
|
||||
content="Authentication Complete! Beginning community profile merge..."
|
||||
)
|
||||
|
||||
# results = await self.crocbot.fetch_users(ids=[authrow.userid])
|
||||
# if not results:
|
||||
# logger.error(
|
||||
# f"User {authrow} obtained from Twitch authentication does not exist."
|
||||
# )
|
||||
# await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||
# return
|
||||
|
||||
# user = results[0]
|
||||
try:
|
||||
user = await first(self.bot.twitch.get_users(user_ids=[str(authrow.userid)]))
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Looking up user {authrow} from Twitch authentication flow raised an error.",
|
||||
exc_info=True
|
||||
)
|
||||
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||
return
|
||||
|
||||
if user is None:
|
||||
logger.error(
|
||||
f"User {authrow} obtained from Twitch authentication does not exist."
|
||||
)
|
||||
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||
return
|
||||
|
||||
# Retrieve author's profile if it exists
|
||||
guild_comm = await Community.fetch_from_discordid(self.bot, ctx.guild.id)
|
||||
|
||||
# Check if the twitch-side user has a profile
|
||||
twitch_comm = await Community.fetch_from_twitchid(self.bot, user.id)
|
||||
|
||||
if guild_comm and twitch_comm is None:
|
||||
# All we need to do is attach the twitch row
|
||||
await guild_comm.attach_twitch(user.id)
|
||||
await message.edit(
|
||||
content=f"Successfully linked Twitch channel **{user.display_name}**! There was no community data to merge."
|
||||
)
|
||||
elif twitch_comm and guild_comm is None:
|
||||
# Attach the discord row to the profile
|
||||
await twitch_comm.attach_discord(ctx.guild.id)
|
||||
await message.edit(
|
||||
content=f"Successfully connected to Twitch channel **{user.display_name}**!"
|
||||
)
|
||||
elif twitch_comm is None and guild_comm is None:
|
||||
profile = await Community.create_from_discord(self.bot, ctx.guild)
|
||||
await profile.attach_twitch(user.id)
|
||||
|
||||
await message.edit(
|
||||
content=f"Created a new community for this server and linked Twitch account **{user.display_name}**."
|
||||
)
|
||||
elif guild_comm.communityid == twitch_comm.communityid:
|
||||
await message.edit(
|
||||
content=f"This server is already linked to the Twitch channel **{user.display_name}**!"
|
||||
)
|
||||
else:
|
||||
# Migrate the existing profile data to the new profiles
|
||||
try:
|
||||
results = await self.migrate_community(twitch_comm, guild_comm)
|
||||
except Exception:
|
||||
await ctx.error_reply(
|
||||
"An issue was encountered while merging your community profiles!\n"
|
||||
"Migration rolled back, no data has been lost.\n"
|
||||
"The developer has been notified. Please try again later!"
|
||||
)
|
||||
raise
|
||||
|
||||
content = '\n'.join((
|
||||
"## Connecting Twitch account and merging community profiles...",
|
||||
*results,
|
||||
"**Successfully linked account and merged community data!**"
|
||||
))
|
||||
await message.edit(content=content)
|
||||
123
src/modules/profiles/community.py
Normal file
123
src/modules/profiles/community.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import Optional, Self
|
||||
|
||||
import discord
|
||||
|
||||
from meta import LionBot
|
||||
from utils.lib import utc_now
|
||||
|
||||
from . import logger
|
||||
from .data import ProfileData
|
||||
|
||||
|
||||
|
||||
class Community:
|
||||
def __init__(self, bot: LionBot, community_row):
|
||||
self.bot = bot
|
||||
self.row: ProfileData.CommunityRow = community_row
|
||||
|
||||
@property
|
||||
def cog(self):
|
||||
return self.bot.get_cog('ProfileCog')
|
||||
|
||||
@property
|
||||
def data(self) -> ProfileData:
|
||||
return self.cog.data
|
||||
|
||||
@property
|
||||
def communityid(self):
|
||||
return self.row.communityid
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Community communityid={self.communityid} row={self.row}>"
|
||||
|
||||
async def attach_discord(self, guildid: int):
|
||||
"""
|
||||
Attach a new discord guild to this community.
|
||||
Assumes the discord guild is not already associated to a community.
|
||||
"""
|
||||
discord_row = await self.data.DiscordCommunityRow.create(
|
||||
communityid=self.communityid,
|
||||
guildid=guildid
|
||||
)
|
||||
logger.info(
|
||||
f"Attached discord guild {guildid} to community {self!r}"
|
||||
)
|
||||
return discord_row
|
||||
|
||||
async def attach_twitch(self, channelid: str):
|
||||
"""
|
||||
Attach a new Twitch user channel to this community.
|
||||
"""
|
||||
twitch_row = await self.data.TwitchCommunityRow.create(
|
||||
communityid=self.communityid,
|
||||
channelid=str(channelid)
|
||||
)
|
||||
logger.info(
|
||||
f"Attached twitch channel {channelid} to community {self!r}"
|
||||
)
|
||||
return twitch_row
|
||||
|
||||
async def discord_guilds(self) -> list[ProfileData.DiscordCommunityRow]:
|
||||
"""
|
||||
Fetch the Discord guild rows associated to this community.
|
||||
"""
|
||||
return await self.data.DiscordCommunityRow.fetch_where(communityid=self.communityid)
|
||||
|
||||
async def twitch_channels(self) -> list[ProfileData.TwitchCommunityRow]:
|
||||
"""
|
||||
Fetch the Twitch user rows associated to this profile.
|
||||
"""
|
||||
return await self.data.TwitchCommunityRow.fetch_where(communityid=self.communityid)
|
||||
|
||||
@classmethod
|
||||
async def fetch(cls, bot: LionBot, community_id: int) -> Self:
|
||||
community_row = await bot.get_cog('ProfileCog').data.CommunityRow.fetch(community_id)
|
||||
if community_row is None:
|
||||
raise ValueError("Provided community_id does not exist.")
|
||||
return cls(bot, community_row)
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_twitchid(cls, bot: LionBot, channelid: int | str) -> Optional[Self]:
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
rows = await data.TwitchCommunityRow.fetch_where(channelid=str(channelid))
|
||||
if rows:
|
||||
return await cls.fetch(bot, rows[0].communityid)
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_discordid(cls, bot: LionBot, guildid: int) -> Optional[Self]:
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
rows = await data.DiscordCommunityRow.fetch_where(guildid=guildid)
|
||||
if rows:
|
||||
return await cls.fetch(bot, rows[0].communityid)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, bot: LionBot, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new empty community with the given initial arguments.
|
||||
|
||||
Communities should usually be created using `create_from_discord` or `create_from_twitch`
|
||||
to correctly setup initial preferences (e.g. name, avatar).
|
||||
"""
|
||||
# Create a new community
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
row = await data.CommunityRow.create(created_at=utc_now(), **kwargs)
|
||||
return await cls.fetch(bot, row.communityid)
|
||||
|
||||
@classmethod
|
||||
async def create_from_discord(cls, bot: LionBot, guild: discord.Guild, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new community using the given Discord guild as a base.
|
||||
"""
|
||||
self = await cls.create(bot, **kwargs)
|
||||
await self.attach_discord(guild.id)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def create_from_twitch(cls, bot: LionBot, user, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new profile using the given Twitch channel user as a base.
|
||||
The provided `user` must have an `id` attribute.
|
||||
"""
|
||||
self = await cls.create(bot, **kwargs)
|
||||
await self.attach_twitch(str(user.id))
|
||||
return self
|
||||
158
src/modules/profiles/data.py
Normal file
158
src/modules/profiles/data.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from data import Registry, RowModel
|
||||
from data.columns import Integer, String, Timestamp
|
||||
|
||||
|
||||
class ProfileData(Registry):
|
||||
class UserProfileRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE user_profiles(
|
||||
profileid SERIAL PRIMARY KEY,
|
||||
nickname TEXT,
|
||||
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'user_profiles'
|
||||
_cache_ = {}
|
||||
|
||||
profileid = Integer(primary=True)
|
||||
nickname = String()
|
||||
migrated = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
|
||||
class DiscordProfileRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE profiles_discord(
|
||||
linkid SERIAL PRIMARY KEY,
|
||||
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
userid BIGINT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid);
|
||||
CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid);
|
||||
"""
|
||||
_tablename_ = 'profiles_discord'
|
||||
_cache_ = {}
|
||||
|
||||
linkid = Integer(primary=True)
|
||||
profileid = Integer()
|
||||
userid = Integer()
|
||||
created_at = Integer()
|
||||
|
||||
@classmethod
|
||||
async def fetch_profile(cls, profileid: int):
|
||||
rows = await cls.fetch_where(profiled=profileid)
|
||||
return next(rows, None)
|
||||
|
||||
|
||||
class TwitchProfileRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE profiles_twitch(
|
||||
linkid SERIAL PRIMARY KEY,
|
||||
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
userid TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid);
|
||||
CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid);
|
||||
"""
|
||||
_tablename_ = 'profiles_twitch'
|
||||
_cache_ = {}
|
||||
|
||||
linkid = Integer(primary=True)
|
||||
profileid = Integer()
|
||||
userid = String()
|
||||
created_at = Timestamp()
|
||||
|
||||
@classmethod
|
||||
async def fetch_profile(cls, profileid: int):
|
||||
rows = await cls.fetch_where(profiled=profileid)
|
||||
return next(rows, None)
|
||||
|
||||
class CommunityRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE communities(
|
||||
communityid SERIAL PRIMARY KEY,
|
||||
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'communities'
|
||||
_cache_ = {}
|
||||
|
||||
communityid = Integer(primary=True)
|
||||
migrated = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
class DiscordCommunityRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE communities_discord(
|
||||
guildid BIGINT PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'communities_discord'
|
||||
_cache_ = {}
|
||||
|
||||
guildid = Integer(primary=True)
|
||||
communityid = Integer()
|
||||
linked_at = Timestamp()
|
||||
|
||||
@classmethod
|
||||
async def fetch_community(cls, communityid: int):
|
||||
rows = await cls.fetch_where(communityd=communityid)
|
||||
return next(rows, None)
|
||||
|
||||
class TwitchCommunityRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE communities_twitch(
|
||||
channelid TEXT PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'communities_twitch'
|
||||
_cache_ = {}
|
||||
|
||||
channelid = String(primary=True)
|
||||
communityid = Integer()
|
||||
linked_at = Timestamp()
|
||||
|
||||
@classmethod
|
||||
async def fetch_community(cls, communityid: int):
|
||||
rows = await cls.fetch_where(communityd=communityid)
|
||||
return next(rows, None)
|
||||
|
||||
class CommunityMemberRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE community_members(
|
||||
memberid SERIAL PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
|
||||
"""
|
||||
_tablename_ = 'community_members'
|
||||
_cache_ = {}
|
||||
|
||||
memberid = Integer(primary=True)
|
||||
communityid = Integer()
|
||||
profileid = Integer()
|
||||
created_at = Timestamp()
|
||||
138
src/modules/profiles/profile.py
Normal file
138
src/modules/profiles/profile.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from typing import Optional, Self
|
||||
|
||||
import discord
|
||||
|
||||
from meta import LionBot
|
||||
from utils.lib import utc_now
|
||||
|
||||
from . import logger
|
||||
from .data import ProfileData
|
||||
|
||||
|
||||
|
||||
class UserProfile:
|
||||
def __init__(self, bot: LionBot, profile_row):
|
||||
self.bot = bot
|
||||
self.profile_row: ProfileData.UserProfileRow = profile_row
|
||||
|
||||
@property
|
||||
def cog(self):
|
||||
return self.bot.get_cog('ProfileCog')
|
||||
|
||||
@property
|
||||
def data(self) -> ProfileData:
|
||||
return self.cog.data
|
||||
|
||||
@property
|
||||
def profileid(self):
|
||||
return self.profile_row.profileid
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UserProfile profileid={self.profileid} profile={self.profile_row}>"
|
||||
|
||||
async def get_name(self) -> Optional[str]:
|
||||
return self.profile_row.nickname
|
||||
|
||||
async def attach_discord(self, userid: int):
|
||||
"""
|
||||
Attach a new discord user to this profile.
|
||||
Assumes the discord user does not itself have a profile.
|
||||
"""
|
||||
discord_row = await self.data.DiscordProfileRow.create(
|
||||
profileid=self.profileid,
|
||||
userid=userid
|
||||
)
|
||||
logger.info(
|
||||
f"Attached discord user {userid} to profile {self!r}"
|
||||
)
|
||||
return discord_row
|
||||
|
||||
async def attach_twitch(self, userid: str):
|
||||
"""
|
||||
Attach a new Twitch user to this profile.
|
||||
"""
|
||||
twitch_row = await self.data.TwitchProfileRow.create(
|
||||
profileid=self.profileid,
|
||||
userid=userid
|
||||
)
|
||||
logger.info(
|
||||
f"Attached twitch user {userid} to profile {self!r}"
|
||||
)
|
||||
return twitch_row
|
||||
|
||||
async def discord_accounts(self) -> list[ProfileData.DiscordProfileRow]:
|
||||
"""
|
||||
Fetch the Discord accounts associated to this profile.
|
||||
"""
|
||||
return await self.data.DiscordProfileRow.fetch_where(
|
||||
profileid=self.profileid
|
||||
).order_by(
|
||||
'created_at'
|
||||
)
|
||||
|
||||
async def twitch_accounts(self) -> list[ProfileData.TwitchProfileRow]:
|
||||
"""
|
||||
Fetch the Twitch accounts associated to this profile.
|
||||
"""
|
||||
return await self.data.TwitchProfileRow.fetch_where(
|
||||
profileid=self.profileid
|
||||
).order_by(
|
||||
'created_at'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def fetch(cls, bot: LionBot, profile_id: int) -> Self:
|
||||
profile_row = await bot.get_cog('ProfileCog').data.UserProfileRow.fetch(profile_id)
|
||||
if profile_row is None:
|
||||
raise ValueError("Provided profile_id does not exist.")
|
||||
return cls(bot, profile_row)
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_twitchid(cls, bot: LionBot, userid: int | str) -> Optional[Self]:
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
rows = await data.TwitchProfileRow.fetch_where(userid=str(userid))
|
||||
if rows:
|
||||
return await cls.fetch(bot, rows[0].profileid)
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_discordid(cls, bot: LionBot, userid: int) -> Optional[Self]:
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
rows = await data.DiscordProfileRow.fetch_where(userid=(userid))
|
||||
if rows:
|
||||
return await cls.fetch(bot, rows[0].profileid)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, bot: LionBot, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new empty profile with the given initial arguments.
|
||||
|
||||
Profiles should usually be created using `create_from_discord` or `create_from_twitch`
|
||||
to correctly setup initial profile preferences (e.g. name, avatar).
|
||||
"""
|
||||
# Create a new profile
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
profile_row = await data.UserProfileRow.create(created_at=utc_now())
|
||||
profile = await cls.fetch(bot, profile_row.profileid)
|
||||
return profile
|
||||
|
||||
@classmethod
|
||||
async def create_from_discord(cls, bot: LionBot, user: discord.Member | discord.User, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new profile using the given Discord user as a base.
|
||||
"""
|
||||
kwargs.setdefault('nickname', user.name)
|
||||
profile = await cls.create(bot, **kwargs)
|
||||
await profile.attach_discord(user.id)
|
||||
return profile
|
||||
|
||||
@classmethod
|
||||
async def create_from_twitch(cls, bot: LionBot, user, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new profile using the given Twitch user as a base.
|
||||
|
||||
Assumes the provided `user` has `id` and `name` attributes.
|
||||
"""
|
||||
kwargs.setdefault('nickname', user.name)
|
||||
profile = await cls.create(bot, **kwargs)
|
||||
await profile.attach_twitch(str(user.id))
|
||||
return profile
|
||||
1
src/modules/profiles/ui/__init__.py
Normal file
1
src/modules/profiles/ui/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .twitchlink import TwitchLinkStatic, TwitchLinkFlow
|
||||
337
src/modules/profiles/ui/twitchlink.py
Normal file
337
src/modules/profiles/ui/twitchlink.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
UI Views for Twitch linkage.
|
||||
|
||||
- Persistent view with interaction-button to enter link flow.
|
||||
- We don't store the view, but we listen to interaction button id.
|
||||
- Command to enter link flow.
|
||||
|
||||
For link flow, send ephemeral embed with instructions and what to expect, with link button below.
|
||||
After auth is granted through OAuth flow (or if not granted, e.g. on timeout or failure)
|
||||
edit the embed to reflect auth situation.
|
||||
|
||||
If migration occurred, add the migration text as a field to the embed.
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from enum import IntEnum
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
from discord.ui.button import button, Button
|
||||
from discord.enums import ButtonStyle
|
||||
from twitchAPI.helper import first
|
||||
|
||||
from meta import LionBot
|
||||
from meta.errors import SafeCancellation
|
||||
from meta.logger import log_wrap
|
||||
|
||||
from utils.ui import MessageUI
|
||||
from utils.lib import MessageArgs, utc_now
|
||||
from utils.ui.leo import LeoUI
|
||||
|
||||
from modules.profiles.profile import UserProfile
|
||||
|
||||
import brand
|
||||
|
||||
from .. import logger
|
||||
|
||||
class TwitchLinkStatic(LeoUI):
|
||||
"""
|
||||
Static UI whose only job is to display a persistent button
|
||||
to ask people to connect their twitch account.
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._embed: Optional[discord.Embed] = None
|
||||
print("INITIALISATION")
|
||||
|
||||
async def interaction_check(self, interaction: discord.Interaction):
|
||||
return True
|
||||
|
||||
@property
|
||||
def embed(self) -> discord.Embed:
|
||||
"""
|
||||
This is the persistent message people will see with the button that starts the Oauth flow.
|
||||
|
||||
Not sure what this should actually say, or whether it should be customisable via command.
|
||||
|
||||
:TODO-MARKER:
|
||||
"""
|
||||
embed = discord.Embed(
|
||||
title="Link your Twitch account!",
|
||||
description=(
|
||||
"To participate in the Dreamspace Adventure Game :TM:, "
|
||||
"please start by pressing the button below to begin the login flow with Twitch!"
|
||||
),
|
||||
colour=brand.ACCENT_COLOUR,
|
||||
)
|
||||
return embed
|
||||
|
||||
@embed.setter
|
||||
def embed(self, value):
|
||||
self._embed = value
|
||||
|
||||
@button(label="Connect", custom_id="BTN-LINK-TWITCH", style=ButtonStyle.green, emoji='🔗')
|
||||
@log_wrap(action="link-twitch-btn")
|
||||
async def button_linker(self, interaction: discord.Interaction, btn: Button):
|
||||
# Here we just reply to the interaction with the AuthFlow UI
|
||||
# TODO
|
||||
print("RESPONDING")
|
||||
flowui = TwitchLinkFlow(interaction.client, interaction.user, callerid=interaction.user.id)
|
||||
await flowui.run(interaction)
|
||||
await flowui.wait()
|
||||
|
||||
|
||||
class FlowState(IntEnum):
|
||||
SETUP = -1
|
||||
WAITING = 0
|
||||
|
||||
CANCELLED = 1
|
||||
TIMEOUT = 2
|
||||
ERRORED = 3
|
||||
|
||||
WORKING = 9
|
||||
DONE = 10
|
||||
|
||||
|
||||
class TwitchLinkFlow(MessageUI):
|
||||
def __init__(self, bot: LionBot, caller: discord.User | discord.Member, *args, **kwargs):
|
||||
kwargs.setdefault('callerid', caller.id)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.bot = bot
|
||||
|
||||
self._auth_task = None
|
||||
self._stage: FlowState = FlowState.SETUP
|
||||
self.flow = None
|
||||
self.authrow = None
|
||||
self.user = caller
|
||||
|
||||
self._info = None
|
||||
self._migration_details = None
|
||||
|
||||
# ----- UI API -----
|
||||
async def run(self, interaction: discord.Interaction, **kwargs):
|
||||
await interaction.response.defer(ephemeral=True, thinking=True)
|
||||
await self._start_flow()
|
||||
await self.draw(interaction, **kwargs)
|
||||
if self._stage is FlowState.ERRORED:
|
||||
# This can happen if starting the flow failed
|
||||
await self.close()
|
||||
|
||||
@log_wrap(action="start-twitch-flow-ui")
|
||||
async def _start_flow(self):
|
||||
logger.info(f"Starting twitch authentication flow for {self.user}")
|
||||
try:
|
||||
self.flow = await self.bot.get_cog('TwitchAuthCog').start_auth()
|
||||
except aiohttp.ClientError:
|
||||
self._stage = FlowState.ERRORED
|
||||
self._info = (
|
||||
"Could not establish a connection to the authentication server! "
|
||||
"Please try again later~"
|
||||
)
|
||||
logger.exception("Unexpected exception while starting authentication flow!", exc_info=True)
|
||||
else:
|
||||
self._stage = FlowState.WAITING
|
||||
self._auth_task = asyncio.create_task(self._auth_flow())
|
||||
|
||||
@log_wrap(action="run-twitch-flow-ui")
|
||||
async def _auth_flow(self):
|
||||
"""
|
||||
Run the flow and wait for a timeout, cancellation, or callback.
|
||||
Update the message accordingly.
|
||||
"""
|
||||
assert self.flow is not None
|
||||
try:
|
||||
# TODO: Cancel this in cleanup
|
||||
authrow = await asyncio.wait_for(self.flow.run(), timeout=60)
|
||||
except asyncio.TimeoutError:
|
||||
self._stage = FlowState.TIMEOUT
|
||||
# Link Timed Out!
|
||||
self._info = (
|
||||
"We didn't receive a response so we closed the uplink "
|
||||
"to keep your account safe! If you still want to connect, please try again!"
|
||||
)
|
||||
await self.refresh()
|
||||
await self.close()
|
||||
except asyncio.CancelledError:
|
||||
# Presumably the user exited or the bot is shutting down.
|
||||
# Not safe to edit the message, but try and cleanup
|
||||
await self.close()
|
||||
except SafeCancellation as e:
|
||||
logger.info("User or server cancelled authentication flow: ", exc_info=True)
|
||||
# Uplink Cancelled!
|
||||
self._info = (
|
||||
f"We couldn't complete the uplink!\nReason:*{e.msg}*"
|
||||
)
|
||||
self._stage = FlowState.CANCELLED
|
||||
await self.refresh()
|
||||
await self.close()
|
||||
except Exception:
|
||||
logger.exception("Something unexpected went wrong while running the flow!")
|
||||
else:
|
||||
self._stage = FlowState.WORKING
|
||||
self._info = (
|
||||
"Authentication complete! Connecting your Dreamspace account ...."
|
||||
)
|
||||
self.authrow = authrow
|
||||
await self.refresh()
|
||||
await self._link_twitch(str(authrow.userid))
|
||||
await self.refresh()
|
||||
await self.close()
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
if self._auth_task and not self._auth_task.cancelled():
|
||||
self._auth_task.cancel()
|
||||
|
||||
async def _link_twitch(self, twitch_id: str):
|
||||
"""
|
||||
Link the caller's profile to the given twitch_id.
|
||||
|
||||
Performs migration if needed.
|
||||
"""
|
||||
try:
|
||||
twitch_user = await first(self.bot.twitch.get_users(user_ids=[twitch_id]))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Looking up user {self.authrow} from Twitch authentication flow raised an error."
|
||||
)
|
||||
self._stage = FlowState.ERRORED
|
||||
self._info = "Failed to look up your user details from Twitch! Please try again later."
|
||||
return
|
||||
|
||||
if twitch_user is None:
|
||||
logger.error(
|
||||
f"User {self.authrow} obtained from Twitch authentication does not exist."
|
||||
)
|
||||
self._stage = FlowState.ERRORED
|
||||
self._info = "Authentication failed! Please try again later."
|
||||
return
|
||||
|
||||
profiles = self.bot.get_cog('ProfileCog')
|
||||
userid = self.user.id
|
||||
|
||||
caller_profile = await UserProfile.fetch_from_discordid(self.bot, userid)
|
||||
twitch_profile = await UserProfile.fetch_from_twitchid(self.bot, twitch_id)
|
||||
|
||||
succ_info = (
|
||||
f"Successfully established uplink to your Twitch account **{twitch_user.display_name}** "
|
||||
"and transferred dreamspace data! Happy adventuring, and watch out for the grue~"
|
||||
)
|
||||
# ::TODO-MARKER::
|
||||
|
||||
if twitch_profile is None:
|
||||
if caller_profile is None:
|
||||
# Neither profile exists
|
||||
profile = await UserProfile.create_from_discord(self.bot, self.user)
|
||||
await profile.attach_twitch(twitch_id)
|
||||
|
||||
self._stage = FlowState.DONE
|
||||
self._info = succ_info
|
||||
else:
|
||||
await caller_profile.attach_twitch(twitch_id)
|
||||
|
||||
self._stage = FlowState.DONE
|
||||
self._info = succ_info
|
||||
else:
|
||||
if caller_profile is None:
|
||||
await twitch_profile.attach_discord(self.user.id)
|
||||
|
||||
self._stage = FlowState.DONE
|
||||
self._info = succ_info
|
||||
elif twitch_profile.profileid == caller_profile.profileid:
|
||||
self._stage = FlowState.CANCELLED
|
||||
self._info = (
|
||||
f"The Twitch account **{twitch_user.display_name}** is already linked to your profile!"
|
||||
)
|
||||
else:
|
||||
# In this case we have conflicting profiles we need to migrate
|
||||
try:
|
||||
results = await profiles.migrate_profile(twitch_profile, caller_profile)
|
||||
except Exception:
|
||||
self._stage = FlowState.ERRORED
|
||||
self._info = (
|
||||
"An issue was encountered while merging your account profiles! "
|
||||
"The migration was rolled back, and not data has been lost.\n"
|
||||
"The developer has been notified, please try again later!"
|
||||
)
|
||||
logger.exception(f"Failed to migrate profiles {twitch_profile=} to {caller_profile=}")
|
||||
else:
|
||||
self._stage = FlowState.DONE
|
||||
self._info = succ_info
|
||||
self._migration_details = '\n'.join(results)
|
||||
logger.info(
|
||||
f"Migrated {twitch_profile=} to {caller_profile}. Info: {self._migration_details}"
|
||||
)
|
||||
|
||||
# ----- UI Flow -----
|
||||
async def make_message(self) -> MessageArgs:
|
||||
if self._stage is FlowState.SETUP:
|
||||
raise ValueError("Making message before flow initialisation!")
|
||||
assert self.flow is not None
|
||||
|
||||
if self._stage is FlowState.WAITING:
|
||||
# Message should be the initial request page
|
||||
dur = discord.utils.format_dt(utc_now() + timedelta(seconds=60), style='R')
|
||||
|
||||
title = "Press the button to login!"
|
||||
desc = (
|
||||
"We have generated a custom secure link for you to connect your Twitch profile! "
|
||||
"Press the button below and accept the connection in your browser, "
|
||||
"and we will begin the transfer!\n"
|
||||
f"(Note: The link expires {dur})"
|
||||
)
|
||||
colour = brand.ACCENT_COLOUR
|
||||
elif self._stage is FlowState.CANCELLED:
|
||||
# Show cancellation message
|
||||
# Show 'you can close this'
|
||||
title = "Uplink Cancelled!"
|
||||
desc = self._info
|
||||
colour = discord.Colour.brand_red()
|
||||
elif self._stage is FlowState.TIMEOUT:
|
||||
title = "Link Timed Out"
|
||||
desc = self._info
|
||||
colour = discord.Colour.brand_red()
|
||||
elif self._stage is FlowState.ERRORED:
|
||||
title = "Something went wrong!"
|
||||
desc = self._info
|
||||
colour = discord.Colour.brand_red()
|
||||
elif self._stage is FlowState.WORKING:
|
||||
# We've received the auth, we are now doing migration
|
||||
title = "Establishing Connection"
|
||||
desc = self._info
|
||||
colour = brand.ACCENT_COLOUR
|
||||
elif self._stage is FlowState.DONE:
|
||||
title = "Success!"
|
||||
desc = self._info
|
||||
colour = discord.Colour.brand_green()
|
||||
else:
|
||||
raise ValueError(f"Invalid stage value {self._stage}")
|
||||
|
||||
embed = discord.Embed(title=title, description=desc, colour=colour, timestamp=utc_now())
|
||||
if self._migration_details:
|
||||
embed.add_field(
|
||||
name="Profile migration details",
|
||||
value=self._migration_details
|
||||
)
|
||||
return MessageArgs(embed=embed)
|
||||
|
||||
async def refresh_layout(self):
|
||||
# If we haven't received the auth callback yet, make the flow link button
|
||||
if self.flow is None:
|
||||
raise ValueError("Refreshing before flow initialisation!")
|
||||
|
||||
if self._stage <= FlowState.WAITING:
|
||||
flow_link = self.flow.auth.return_auth_url()
|
||||
button = Button(
|
||||
style=ButtonStyle.link,
|
||||
url=flow_link,
|
||||
label="Login With Twitch"
|
||||
)
|
||||
self.set_layout((button,))
|
||||
else:
|
||||
self.set_layout(())
|
||||
|
||||
async def reload(self):
|
||||
pass
|
||||
16
src/routes/__init__.py
Normal file
16
src/routes/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from .stamps import routes as stamp_routes
|
||||
from .documents import routes as doc_routes
|
||||
from .users import routes as user_routes
|
||||
from .specimens import routes as spec_routes
|
||||
from .transactions import routes as txn_routes
|
||||
from .events import routes as event_routes
|
||||
from .lib import dbvar, datamodelsv, profiledatav
|
||||
|
||||
|
||||
def register_routes(router):
|
||||
router.add_routes(stamp_routes)
|
||||
router.add_routes(doc_routes)
|
||||
router.add_routes(user_routes)
|
||||
router.add_routes(spec_routes)
|
||||
router.add_routes(event_routes)
|
||||
router.add_routes(txn_routes)
|
||||
370
src/routes/documents.py
Normal file
370
src/routes/documents.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
- `/documents` with `POST, GET`
|
||||
- `/documents/{document_id}` with `GET`, `PATCH`, `DELETE`
|
||||
- `/documents/{document_id}/stamps` which is passed to `/stamps` with `document_id` set.
|
||||
"""
|
||||
import logging
|
||||
import binascii
|
||||
from datetime import datetime
|
||||
from typing import Any, NamedTuple, Optional, Self, TypedDict, Unpack, reveal_type, List
|
||||
from aiohttp import web
|
||||
import discord
|
||||
from data import Condition, condition
|
||||
from data.queries import JOINTYPE
|
||||
from datamodels import DataModel
|
||||
from utils.lib import MessageArgs, tabulate
|
||||
|
||||
from .lib import ModelField, datamodelsv, event_log
|
||||
from .stamps import Stamp, StampCreateParams, StampEditParams, StampPayload
|
||||
|
||||
routes = web.RouteTableDef()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocPayload(TypedDict):
|
||||
document_id: int
|
||||
document_data: str
|
||||
seal: int
|
||||
created_at: str
|
||||
metadata: Optional[str]
|
||||
stamps: List[StampPayload]
|
||||
|
||||
|
||||
class DocCreateParamsReq(TypedDict, total=True):
|
||||
document_data: str
|
||||
seal: int
|
||||
|
||||
|
||||
class DocCreateParams(DocCreateParamsReq, total=False):
|
||||
metadata: Optional[str]
|
||||
stamps: List[StampCreateParams]
|
||||
|
||||
|
||||
class DocEditParams(TypedDict, total=False):
|
||||
document_data: str
|
||||
seal: int
|
||||
metadata: Optional[str]
|
||||
stamps: List[StampCreateParams]
|
||||
|
||||
|
||||
fields = [
|
||||
ModelField('document_id', int, False, False, False),
|
||||
ModelField('document_data', str, True, True, True),
|
||||
ModelField('seal', int, True, True, True),
|
||||
ModelField('created_at', str, False, False, False),
|
||||
ModelField('metadata', Optional[str], False, True, True),
|
||||
ModelField('stamps', List[StampCreateParams], False, True, True),
|
||||
]
|
||||
req_fields = {field.name for field in fields if field.required}
|
||||
edit_fields = {field.name for field in fields if field.can_edit}
|
||||
create_fields = {field.name for field in fields if field.can_create}
|
||||
|
||||
|
||||
class Document:
|
||||
def __init__(self, app: web.Application, row: DataModel.Document):
|
||||
self.app = app
|
||||
self.data = app[datamodelsv]
|
||||
self.row = row
|
||||
|
||||
@classmethod
|
||||
async def validate_create_params(cls, params):
|
||||
if extra := next((key for key in params if key not in create_fields), None):
|
||||
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to document creation.")
|
||||
if missing := next((key for key in req_fields if key not in params), None):
|
||||
raise web.HTTPBadRequest(text=f"Document params missing required key '{missing}'")
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_id(cls, app: web.Application, document_id: int) -> Optional[Self]:
|
||||
data = app[datamodelsv]
|
||||
row = await data.Document.fetch(document_id)
|
||||
return cls(app, row) if row is not None else None
|
||||
|
||||
@classmethod
|
||||
async def query(
|
||||
cls,
|
||||
app: web.Application,
|
||||
document_id: Optional[int] = None,
|
||||
seal: Optional[int] = None,
|
||||
created_before: Optional[str] = None,
|
||||
created_after: Optional[str] = None,
|
||||
metadata: Optional[str] = None,
|
||||
stamp_type: Optional[str] = None,
|
||||
) -> List[Self]:
|
||||
data = app[datamodelsv]
|
||||
Doc = data.Document
|
||||
|
||||
conds = []
|
||||
if document_id is not None:
|
||||
conds.append(Doc.document_id == int(document_id))
|
||||
if seal is not None:
|
||||
conds.append(Doc.seal == int(seal))
|
||||
if created_before is not None:
|
||||
cbefore = datetime.fromisoformat(created_before)
|
||||
conds.append(Doc.created_at <= cbefore)
|
||||
if created_after is not None:
|
||||
cafter = datetime.fromisoformat(created_after)
|
||||
conds.append(Doc.created_at >= cafter)
|
||||
if metadata is not None:
|
||||
conds.append(Doc.metadata == metadata)
|
||||
|
||||
query = data.Document.table.fetch_rows_where(*conds)
|
||||
# results = await query
|
||||
|
||||
# query = data.Document.table.select_where(*conds)
|
||||
if stamp_type is not None:
|
||||
query.join('document_stamps', using=('document_id',), join_type=JOINTYPE.LEFT)
|
||||
query.join(
|
||||
'stamp_types',
|
||||
on=(data.DocumentStamp.stamp_type == data.StampType.stamp_type_id)
|
||||
)
|
||||
query.where(data.StampType.stamp_type_name == stamp_type)
|
||||
query.select(docid = "DISTINCT(document_id)")
|
||||
query.with_no_adapter()
|
||||
results = await query
|
||||
ids = [result['docid'] for result in results]
|
||||
if ids:
|
||||
rows = await data.Document.table.fetch_rows_where(
|
||||
document_id=ids
|
||||
)
|
||||
else:
|
||||
rows = []
|
||||
else:
|
||||
rows = await query
|
||||
|
||||
return [cls(app, row) for row in sorted(rows, key=lambda row:row.created_at)]
|
||||
|
||||
@classmethod
|
||||
async def create(cls, app: web.Application, **kwargs: Unpack[DocCreateParams]) -> Self:
|
||||
data = app[datamodelsv]
|
||||
|
||||
document_data = kwargs['document_data']
|
||||
seal = kwargs['seal']
|
||||
stamp_params = kwargs.get('stamps', [])
|
||||
metadata = kwargs.get('metadata')
|
||||
|
||||
# Build the document first
|
||||
row = await data.Document.create(
|
||||
document_data=document_data,
|
||||
seal=seal,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Then build the stamps
|
||||
stamps = []
|
||||
for stampdata in stamp_params:
|
||||
stampdata.setdefault('document_id', row.document_id)
|
||||
stamp = await Stamp.create(app, **stampdata)
|
||||
stamps.append(stamp)
|
||||
|
||||
self = cls(app, row)
|
||||
# await self.log_create()
|
||||
return self
|
||||
|
||||
async def get_stamps(self) -> List[Stamp]:
|
||||
stamprows = await self.data.DocumentStamp.table.fetch_rows_where(document_id=self.row.document_id).order_by('stamp_id')
|
||||
return [Stamp(self.app, row) for row in stamprows]
|
||||
|
||||
async def log_create(self, with_image=True):
|
||||
args = await self.event_log_args(with_image=with_image)
|
||||
args.kwargs['embed'].title = f"Document #{self.row.document_id} Created!"
|
||||
try:
|
||||
await event_log(**args.send_args)
|
||||
except discord.HTTPException:
|
||||
if with_image:
|
||||
# Try again without the image in case that was the issue
|
||||
await self.log_create(with_image=False)
|
||||
|
||||
async def log_edit(self, with_image=True):
|
||||
args = await self.event_log_args(with_image=with_image)
|
||||
args.kwargs['embed'].title = f"Document #{self.row.document_id} Updated!"
|
||||
try:
|
||||
await event_log(**args.send_args)
|
||||
except discord.HTTPException:
|
||||
if with_image:
|
||||
# Try again without the image in case that was the issue
|
||||
await self.log_create(with_image=False)
|
||||
|
||||
async def event_log_args(self, with_image=True) -> MessageArgs:
|
||||
desc = '\n'.join(await self.tabulate())
|
||||
embed = discord.Embed(description=desc, timestamp=self.row.created_at)
|
||||
embed.set_footer(text='Created At')
|
||||
args: dict = {'embed': embed}
|
||||
|
||||
if with_image:
|
||||
try:
|
||||
imagedata = self.row.to_bytes()
|
||||
imagedata.seek(0)
|
||||
embed.set_image(url='attachment://document.png')
|
||||
args['files'] = [discord.File(imagedata, "document.png")]
|
||||
except binascii.Error:
|
||||
# Could not decode base64
|
||||
embed.add_field(name='Image', value="Could not decode document data!")
|
||||
else:
|
||||
embed.add_field(name='Image', value="Failed to send image!")
|
||||
return MessageArgs(**args)
|
||||
|
||||
async def tabulate(self):
|
||||
"""
|
||||
Present the Document as a discord-readable table.
|
||||
"""
|
||||
stamps = await self.get_stamps()
|
||||
typnames = []
|
||||
if stamps:
|
||||
typs = {stamp.row.stamp_type for stamp in stamps}
|
||||
for typ in typs:
|
||||
# Stamp types should be cached so this isn't expensive
|
||||
typrow = await self.data.StampType.fetch(typ)
|
||||
typnames.append(typrow.stamp_type_name)
|
||||
|
||||
table = {
|
||||
'document_id': f"`{self.row.document_id}`",
|
||||
'seal': str(self.row.seal),
|
||||
'metadata': f"`{self.row.metadata}`" if self.row.metadata else "No metadata",
|
||||
'stamps': ', '.join(f"`{name}`" for name in typnames) if typnames else "No stamps",
|
||||
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
|
||||
}
|
||||
return tabulate(*table.items())
|
||||
|
||||
|
||||
async def prepare(self) -> DocPayload:
|
||||
stamps = await self.get_stamps()
|
||||
|
||||
results: DocPayload = {
|
||||
'document_id': self.row.document_id,
|
||||
'document_data': self.row.document_data,
|
||||
'seal': self.row.seal,
|
||||
'created_at': self.row.created_at.isoformat(),
|
||||
'metadata': self.row.metadata,
|
||||
'stamps': [await stamp.prepare() for stamp in stamps]
|
||||
}
|
||||
return results
|
||||
|
||||
async def edit(self, **kwargs: Unpack[DocEditParams]):
|
||||
data = self.data
|
||||
row = self.row
|
||||
# Update the row data
|
||||
# If stamps are given, delete the existing ones
|
||||
# Then write in the new ones.
|
||||
update_args = {}
|
||||
for key in {'document_data', 'seal', 'metadata'}:
|
||||
if key in kwargs:
|
||||
update_args[key] = kwargs[key]
|
||||
if update_args:
|
||||
await self.row.update(**update_args)
|
||||
|
||||
# TODO: Should really be in a transaction
|
||||
# Actually each handler should be in a transaction
|
||||
if new_stamps := kwargs.get('stamps', []):
|
||||
await self.data.DocumentStamp.table.delete_where(document_id=self.row.document_id)
|
||||
for stampdata in new_stamps:
|
||||
stampdata.setdefault('document_id', row.document_id)
|
||||
await Stamp.create(self.app, **stampdata)
|
||||
await self.log_edit()
|
||||
|
||||
async def delete(self) -> DocPayload:
|
||||
payload = await self.prepare()
|
||||
await self.row.delete()
|
||||
return payload
|
||||
|
||||
|
||||
@routes.view('/documents')
|
||||
@routes.view('/documents/')
|
||||
class DocumentsView(web.View):
|
||||
async def get(self):
|
||||
request = self.request
|
||||
filter_params = {}
|
||||
keys = [
|
||||
'document_id',
|
||||
'seal',
|
||||
'created_before',
|
||||
'created_after',
|
||||
'metadata',
|
||||
'stamp_type',
|
||||
]
|
||||
for key in keys:
|
||||
if key in request.query:
|
||||
filter_params[key] = request.query[key]
|
||||
elif key in request:
|
||||
filter_params[key] = request[key]
|
||||
|
||||
documents = await Document.query(request.app, **filter_params)
|
||||
payload = [await doc.prepare() for doc in documents]
|
||||
return web.json_response(payload)
|
||||
|
||||
async def post(self):
|
||||
request = self.request
|
||||
|
||||
params = await request.json()
|
||||
for key in create_fields:
|
||||
if key in request:
|
||||
params.setdefault(key, request[key])
|
||||
|
||||
await Document.validate_create_params(params)
|
||||
|
||||
document = await Document.create(self.request.app, **params)
|
||||
payload = await document.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
@routes.view('/documents/{document_id}')
|
||||
@routes.view('/documents/{document_id}/')
|
||||
class DocumentView(web.View):
|
||||
|
||||
async def resolve_document(self):
|
||||
request = self.request
|
||||
document_id = request.match_info['document_id']
|
||||
document = await Document.fetch_from_id(request.app, int(document_id))
|
||||
if document is None:
|
||||
raise web.HTTPNotFound(text="No document exists with the given ID.")
|
||||
return document
|
||||
|
||||
async def get(self):
|
||||
doc = await self.resolve_document()
|
||||
payload = await doc.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def patch(self):
|
||||
doc = await self.resolve_document()
|
||||
params = await self.request.json()
|
||||
|
||||
edit_data = {}
|
||||
for key, value in params.items():
|
||||
if key not in edit_fields:
|
||||
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of Document!")
|
||||
edit_data[key] = value
|
||||
|
||||
for key in edit_fields:
|
||||
if key in self.request:
|
||||
edit_data.setdefault(key, self.request[key])
|
||||
|
||||
await doc.edit(**edit_data)
|
||||
payload = await doc.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def delete(self):
|
||||
doc = await self.resolve_document()
|
||||
payload = await doc.delete()
|
||||
return web.json_response(payload)
|
||||
|
||||
|
||||
# We have one prefix route, /documents/{document_id}/stamps
|
||||
@routes.route('*', "/documents/{document_id}{tail:/stamps}")
|
||||
@routes.route('*', "/documents/{document_id}{tail:/stamps/.*}")
|
||||
async def document_stamps_route(request: web.Request):
|
||||
document_id = int(request.match_info['document_id'])
|
||||
document = await Document.fetch_from_id(request.app, document_id)
|
||||
if document is None:
|
||||
raise web.HTTPNotFound(text="No document exists with the given ID.")
|
||||
|
||||
new_path = request.match_info['tail']
|
||||
logger.info(f"Redirecting {request=} to {new_path=} and setting {document_id=}")
|
||||
new_request = request.clone(rel_url=new_path)
|
||||
new_request['document_id'] = document_id
|
||||
match_info = await request.app.router.resolve(new_request)
|
||||
match_info.current_app = request.app
|
||||
new_request._match_info = match_info
|
||||
if match_info.handler:
|
||||
return await match_info.handler(new_request)
|
||||
else:
|
||||
raise web.HTTPNotFound()
|
||||
|
||||
|
||||
505
src/routes/events.py
Normal file
505
src/routes/events.py
Normal file
@@ -0,0 +1,505 @@
|
||||
import binascii
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List
|
||||
from aiohttp import web
|
||||
import discord
|
||||
from data import Condition, condition
|
||||
from data.conditions import NULL
|
||||
from data.queries import JOINTYPE
|
||||
from datamodels import DataModel, EventType
|
||||
|
||||
from modules.profiles.data import ProfileData
|
||||
from utils.lib import MessageArgs, tabulate
|
||||
|
||||
from .lib import ModelField, datamodelsv, dbvar, event_log, profiledatav
|
||||
from .specimens import Specimen, SpecimenPayload
|
||||
|
||||
routes = web.RouteTableDef()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Event:
|
||||
def __init__(self, app: web.Application, row: DataModel.EventDetails):
|
||||
self.app = app
|
||||
self.data = app[datamodelsv]
|
||||
|
||||
self.row = row
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_id(cls, app: web.Application, event_id: int):
|
||||
data = app[datamodelsv]
|
||||
row = await data.EventDetails.fetch(int(event_id))
|
||||
return cls(app, row) if row is not None else None
|
||||
|
||||
@classmethod
|
||||
async def query(
|
||||
cls,
|
||||
app: web.Application,
|
||||
event_id: Optional[str] = None,
|
||||
document_id: Optional[str] = None,
|
||||
document_seal: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
occurred_before: Optional[str] = None,
|
||||
occurred_after: Optional[str] = None,
|
||||
created_before: Optional[str] = None,
|
||||
created_after: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
) -> List[Self]:
|
||||
data = app[datamodelsv]
|
||||
EventD = data.EventDetails
|
||||
|
||||
conds = []
|
||||
if event_id is not None:
|
||||
conds.append(EventD.event_id == int(event_id))
|
||||
if document_id is not None:
|
||||
conds.append(EventD.document_id == int(document_id))
|
||||
if document_seal is not None:
|
||||
conds.append(EventD.document_seal == int(document_seal))
|
||||
if user_id is not None:
|
||||
conds.append(EventD.user_id == int(user_id))
|
||||
if user_name is not None:
|
||||
conds.append(EventD.user_name == user_name)
|
||||
if created_before is not None:
|
||||
cbefore = datetime.fromisoformat(created_before)
|
||||
conds.append(EventD.created_at <= cbefore)
|
||||
if created_after is not None:
|
||||
cafter = datetime.fromisoformat(created_after)
|
||||
conds.append(EventD.created_at >= cafter)
|
||||
if occurred_before is not None:
|
||||
before = datetime.fromisoformat(occurred_before)
|
||||
conds.append(EventD.occurred_at <= before)
|
||||
if occurred_after is not None:
|
||||
after = datetime.fromisoformat(occurred_after)
|
||||
conds.append(EventD.occurred_at >= after)
|
||||
if event_type is not None:
|
||||
ekey = (event_type.lower().strip(),)
|
||||
if ekey not in [e.value for e in EventType]:
|
||||
raise web.HTTPBadRequest(text=f"Unknown event type '{event_type}'")
|
||||
conds.append(EventD.event_type == EventType(ekey))
|
||||
|
||||
rows = await EventD.fetch_where(*conds).order_by(EventD.occurred_at)
|
||||
return [cls(app, row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
async def validate_create_params(cls, params):
|
||||
if 'event_type' not in params:
|
||||
raise web.HTTPBadRequest(text="Event creation missing required field 'event_type'.")
|
||||
|
||||
ekey = (params['event_type'].lower().strip(),)
|
||||
if ekey not in [e.value for e in EventType]:
|
||||
raise web.HTTPBadRequest(text=f"Unknown event type '{params['event_type']}'")
|
||||
event_type = EventType(ekey)
|
||||
|
||||
req_fields = {
|
||||
'user_name', 'occurred_at', 'event_type',
|
||||
}
|
||||
other_fields = {
|
||||
'document_id', 'document',
|
||||
'user_id', 'user',
|
||||
}
|
||||
|
||||
if 'user_id' not in params and 'user' not in params:
|
||||
raise web.HTTPBadRequest(text="One of 'user_id' or 'user' must be supplied to create Event.")
|
||||
|
||||
match event_type:
|
||||
case EventType.PLAIN:
|
||||
req_fields.add('message')
|
||||
case EventType.SUBSCRIBER:
|
||||
req_fields.add('tier')
|
||||
req_fields.add('subscribed_length')
|
||||
other_fields.add('message')
|
||||
case EventType.CHEER:
|
||||
req_fields.add('amount')
|
||||
other_fields.add('cheer_type')
|
||||
other_fields.add('message')
|
||||
case EventType.RAID:
|
||||
req_fields.add('viewer_count')
|
||||
|
||||
create_fields = req_fields.union(other_fields)
|
||||
|
||||
if extra := next((key for key in params if key not in create_fields), None):
|
||||
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to {event_type} event creation.")
|
||||
if missing := next((key for key in req_fields if key not in params), None):
|
||||
raise web.HTTPBadRequest(text=f"{event_type} Event params missing required key '{missing}'")
|
||||
|
||||
|
||||
@classmethod
|
||||
async def create(cls, app: web.Application, **kwargs):
|
||||
data = app[datamodelsv]
|
||||
# EventD = data.EventDetails
|
||||
|
||||
ekey = (kwargs['event_type'].lower().strip(),)
|
||||
if ekey not in [e.value for e in EventType]:
|
||||
raise web.HTTPBadRequest(text=f"Unknown event type '{kwargs['event_type']}'")
|
||||
event_type = EventType(ekey)
|
||||
|
||||
params = {}
|
||||
typparams = {}
|
||||
|
||||
match event_type:
|
||||
case EventType.PLAIN:
|
||||
typtab = data.plain_events
|
||||
typparams['message'] = kwargs['message']
|
||||
case EventType.CHEER:
|
||||
typtab = data.cheer_events
|
||||
typparams['amount'] = kwargs['amount']
|
||||
typparams['cheer_type'] = kwargs.get('cheer_type')
|
||||
typparams['message'] = kwargs.get('message')
|
||||
case EventType.RAID:
|
||||
typtab = data.raid_events
|
||||
typparams['visitor_count'] = kwargs.get('viewer_count')
|
||||
case EventType.SUBSCRIBER:
|
||||
typtab = data.subscriber_events
|
||||
typparams['tier'] = kwargs['tier']
|
||||
typparams['subscribed_length'] = kwargs['subscribed_length']
|
||||
typparams['message'] = kwargs.get('message')
|
||||
case _:
|
||||
raise ValueError("Invalid EventType")
|
||||
|
||||
# TODO: This really really should be a transaction
|
||||
|
||||
# Create Document if required
|
||||
if 'document' in kwargs:
|
||||
from .documents import Document
|
||||
doc_args = kwargs['document']
|
||||
await Document.validate_create_params(doc_args)
|
||||
doc = await Document.create(app, **doc_args)
|
||||
document_id = doc.row.document_id
|
||||
params['document_id'] = document_id
|
||||
elif 'document_id' in kwargs:
|
||||
document_id = kwargs['document_id']
|
||||
params['document_id'] = document_id
|
||||
|
||||
# Create User if required
|
||||
if 'user' in kwargs:
|
||||
from .users import User
|
||||
user_args = kwargs['user']
|
||||
await User.validate_create_params(user_args)
|
||||
user = await User.create(app, **user_args)
|
||||
user_id = user.row.user_id
|
||||
|
||||
if 'user_id' in kwargs and not kwargs['user_id'] == user_id:
|
||||
raise web.HTTPBadRequest(text="Provided 'user_id' does not match provided 'user'.")
|
||||
else:
|
||||
user_id = kwargs['user_id']
|
||||
params['user_id'] = user_id
|
||||
|
||||
# Create Event row
|
||||
params['event_type'] = event_type
|
||||
params['user_name'] = kwargs['user_name']
|
||||
params['occurred_at'] = datetime.fromisoformat(kwargs['occurred_at'])
|
||||
|
||||
eventrow = await data.Events.create(**params)
|
||||
typparams['event_id'] = eventrow.event_id
|
||||
|
||||
# Create Event type row
|
||||
typrow = await typtab.insert(**typparams)
|
||||
|
||||
details = await data.EventDetails.fetch(eventrow.event_id)
|
||||
assert details is not None
|
||||
self = cls(app, details)
|
||||
await self.log_create()
|
||||
return self
|
||||
|
||||
async def log_create(self, with_image=True):
|
||||
args = await self.event_log_args(with_image=with_image)
|
||||
args.kwargs['embed'].title = f"Event #{self.row.event_id} Created!"
|
||||
try:
|
||||
await event_log(**args.send_args)
|
||||
except discord.HTTPException:
|
||||
if with_image:
|
||||
# Try again without the image in case that was the issue
|
||||
await self.log_create(with_image=False)
|
||||
|
||||
async def log_edit(self, with_image=True):
|
||||
args = await self.event_log_args(with_image=with_image)
|
||||
args.kwargs['embed'].title = f"Event #{self.row.event_id} Updated!"
|
||||
try:
|
||||
await event_log(**args.send_args)
|
||||
except discord.HTTPException:
|
||||
if with_image:
|
||||
# Try again without the image in case that was the issue
|
||||
await self.log_create(with_image=False)
|
||||
|
||||
async def event_log_args(self, with_image=True) -> MessageArgs:
|
||||
desc = '\n'.join(await self.tabulate())
|
||||
embed = discord.Embed(description=desc, timestamp=self.row.created_at)
|
||||
embed.set_footer(text='Created At')
|
||||
args: dict = {'embed': embed}
|
||||
|
||||
doc = await self.get_document()
|
||||
if doc is not None:
|
||||
embed.add_field(
|
||||
name="Document",
|
||||
value='\n'.join(await doc.tabulate()),
|
||||
inline=False
|
||||
)
|
||||
if with_image:
|
||||
try:
|
||||
imagedata = doc.row.to_bytes()
|
||||
imagedata.seek(0)
|
||||
embed.set_image(url='attachment://document.png')
|
||||
args['files'] = [discord.File(imagedata, "document.png")]
|
||||
except binascii.Error:
|
||||
# Could not decode base64
|
||||
embed.add_field(name='Image', value="Could not decode document data!")
|
||||
else:
|
||||
embed.add_field(name='Image', value="Failed to send image!")
|
||||
return MessageArgs(**args)
|
||||
|
||||
async def tabulate(self):
|
||||
"""
|
||||
Present the Event as a discord-readable table.
|
||||
"""
|
||||
user = await self.get_user()
|
||||
assert user is not None
|
||||
|
||||
table = {
|
||||
'event_id': f"`{self.row.event_id}`",
|
||||
'event_type': f"`{self.row.event_type}`",
|
||||
'user': f"`{self.row.user_id}` (`{self.row.user_name}`)",
|
||||
'document': f"`{self.row.document_id}`",
|
||||
'occurred_at': discord.utils.format_dt(self.row.occurred_at, 'F'),
|
||||
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
|
||||
}
|
||||
info = self.row.event_type.info()
|
||||
for col, param in zip(info.detailcolumns, info.params):
|
||||
value = getattr(self.row, col)
|
||||
table[param] = f"`{value}`"
|
||||
return tabulate(*table.items())
|
||||
|
||||
async def edit(self, **kwargs):
|
||||
data = self.data
|
||||
# EventD = data.EventDetails
|
||||
|
||||
if 'event_type' in kwargs:
|
||||
raise web.HTTPBadRequest(text="You cannot change the type of an event after creation.")
|
||||
|
||||
typparams = {}
|
||||
|
||||
match self.row.event_type:
|
||||
case EventType.PLAIN:
|
||||
typtab = data.plain_events
|
||||
if 'message' in kwargs:
|
||||
typparams['message'] = kwargs['message']
|
||||
case EventType.CHEER:
|
||||
typtab = data.cheer_events
|
||||
for key in ('amount', 'cheer_type', 'message'):
|
||||
if key in kwargs:
|
||||
typparams[key] = kwargs[key]
|
||||
case EventType.RAID:
|
||||
typtab = data.raid_events
|
||||
if 'viewer_count' in kwargs:
|
||||
typparams['visitor_count'] = 'viewer_count'
|
||||
case EventType.SUBSCRIBER:
|
||||
typtab = data.subscriber_events
|
||||
for key in ('tier', 'subscribed_length', 'message'):
|
||||
if key in kwargs:
|
||||
typparams[key] = kwargs[key]
|
||||
if typparams:
|
||||
await typtab.update_where(event_id=self.row.event_id).set(**typparams)
|
||||
|
||||
await self.log_edit()
|
||||
await self.row.refresh()
|
||||
|
||||
async def delete(self):
|
||||
payload = await self.prepare()
|
||||
if self.row.document_id:
|
||||
await self.data.Document.table.delete_where(document_id=self.row.document_id)
|
||||
await self.data.Events.table.delete_where(event_id=self.row.event_id)
|
||||
await self.row.refresh()
|
||||
return payload
|
||||
|
||||
async def get_user(self):
|
||||
from .users import User
|
||||
return await User.fetch_from_id(self.app, self.row.user_id)
|
||||
|
||||
async def get_document(self):
|
||||
from .documents import Document
|
||||
if self.row.document_id:
|
||||
return await Document.fetch_from_id(self.app, self.row.document_id)
|
||||
|
||||
async def prepare(self):
|
||||
row = await self.row.refresh()
|
||||
assert row is not None
|
||||
data = self.data
|
||||
|
||||
user = await self.get_user()
|
||||
assert user is not None
|
||||
document = await self.get_document()
|
||||
|
||||
payload = {
|
||||
'event_id': self.row.event_id,
|
||||
'document_id': self.row.document_id,
|
||||
'document': await document.prepare() if document else None,
|
||||
'user_id': self.row.user_id,
|
||||
'user': await user.prepare(),
|
||||
'user_name': self.row.user_name,
|
||||
'occurred_at': self.row.occurred_at.isoformat(),
|
||||
'created_at': self.row.created_at.isoformat(),
|
||||
'event_type': self.row.event_type.value[0],
|
||||
}
|
||||
|
||||
match row.event_type:
|
||||
case EventType.PLAIN:
|
||||
payload['message'] = row.plain_message
|
||||
case EventType.SUBSCRIBER:
|
||||
payload['tier'] = row.subscriber_tier
|
||||
payload['subscribed_length'] = row.subscriber_length
|
||||
payload['message'] = row.subscriber_message
|
||||
case EventType.CHEER:
|
||||
payload['amount'] = row.cheer_amount
|
||||
payload['cheer_type'] = row.cheer_type
|
||||
payload['message'] = row.cheer_message
|
||||
case EventType.RAID:
|
||||
payload['viewer_count'] = row.raid_visitor_count
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
@routes.view('/events')
|
||||
@routes.view('/events/', name='events')
|
||||
class EventsView(web.View):
|
||||
async def post(self):
|
||||
request = self.request
|
||||
|
||||
params = await request.json()
|
||||
if 'user_id' in request:
|
||||
params.setdefault('user_id', request['user_id'])
|
||||
|
||||
await Event.validate_create_params(params)
|
||||
logger.info(f"Creating a new event with args: {params=}")
|
||||
event = await Event.create(self.request.app, **params)
|
||||
logger.debug(f"Created event: {event!r}")
|
||||
payload = await event.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def get(self):
|
||||
request = self.request
|
||||
filter_params = {}
|
||||
keys = [
|
||||
'event_id', 'document_id', 'document_seal',
|
||||
'user_id', 'user_name', 'occurred_before', 'occurred_after',
|
||||
'created_before', 'created_after', 'event_type',
|
||||
]
|
||||
for key in keys:
|
||||
value = request.query.get(key, request.get(key, None))
|
||||
filter_params[key] = value
|
||||
|
||||
logger.info(f"Querying events with params: {filter_params=}")
|
||||
events = await Event.query(request.app, **filter_params)
|
||||
payload = [await event.prepare() for event in events]
|
||||
return web.json_response(payload)
|
||||
|
||||
|
||||
@routes.view('/events/{event_id}')
|
||||
@routes.view('/events/{event_id}/', name='event')
|
||||
class EventView(web.View):
|
||||
async def resolve_event(self):
|
||||
request = self.request
|
||||
event_id = request.match_info['event_id']
|
||||
event = await Event.fetch_from_id(request.app, int(event_id))
|
||||
if event is None:
|
||||
raise web.HTTPNotFound(text="No event exists with the given ID.")
|
||||
return event
|
||||
|
||||
async def get(self):
|
||||
event = await self.resolve_event()
|
||||
logger.info(f"Received GET for event {event=}")
|
||||
payload = await event.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def patch(self):
|
||||
event = await self.resolve_event()
|
||||
params = await self.request.json()
|
||||
|
||||
edit_data = {}
|
||||
edit_fields = {'message', 'amount', 'cheer_type', 'viewer_count', 'tier', 'subscriber_length', 'message'}
|
||||
for key, value in params.items():
|
||||
if key not in edit_fields:
|
||||
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of User!")
|
||||
edit_data[key] = value
|
||||
|
||||
for key in edit_fields:
|
||||
if key in self.request:
|
||||
edit_data.setdefault(key, self.request[key])
|
||||
|
||||
logger.info(f"Received PATCH for event {event} with params: {params}")
|
||||
await event.edit(**edit_data)
|
||||
payload = await event.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def delete(self):
|
||||
event = await self.resolve_event()
|
||||
logger.info(f"Received DELETE for event {event}")
|
||||
payload = await event.delete()
|
||||
return web.json_response(payload)
|
||||
|
||||
|
||||
@routes.route('*', "/events/{event_id}/user")
|
||||
@routes.route('*', "/events/{event_id}/user{tail:/.*}")
|
||||
async def event_user_route(request: web.Request):
|
||||
event_id = int(request.match_info['event_id'])
|
||||
event = await Event.fetch_from_id(request.app, event_id)
|
||||
if event is None:
|
||||
raise web.HTTPNotFound(text="No event exists with the given ID.")
|
||||
|
||||
tail = request.match_info.get('tail', '')
|
||||
new_path = "/users/{user_id}".format(user_id=event.row.user_id) + tail
|
||||
|
||||
logger.info(f"Redirecting {request=} to {new_path}")
|
||||
new_request = request.clone(rel_url=new_path)
|
||||
new_request['user_id'] = event.row.user_id
|
||||
match_info = await request.app.router.resolve(new_request)
|
||||
new_request._match_info = match_info
|
||||
match_info.current_app = request.app
|
||||
|
||||
if match_info.handler:
|
||||
return await match_info.handler(new_request)
|
||||
else:
|
||||
logger.info(f"Could not find handler matching {new_request}")
|
||||
raise web.HTTPNotFound()
|
||||
|
||||
|
||||
@routes.route('*', "/events/{event_id}/document")
|
||||
@routes.route('*', "/events/{event_id}/document{tail:/.*}")
|
||||
async def event_document_route(request: web.Request):
|
||||
event_id = int(request.match_info['event_id'])
|
||||
event = await Event.fetch_from_id(request.app, event_id)
|
||||
if event is None:
|
||||
raise web.HTTPNotFound(text="No event exists with the given ID.")
|
||||
|
||||
tail = request.match_info.get('tail', '')
|
||||
|
||||
document = await event.get_document()
|
||||
if document is None:
|
||||
if request.method == 'POST' and not tail:
|
||||
new_path = '/documents'
|
||||
logger.info(f"Redirecting {request=} to POST /documents")
|
||||
new_request = request.clone(rel_url=new_path)
|
||||
new_request['event_id'] = event_id
|
||||
match_info = await request.app.router.resolve(new_request)
|
||||
new_request._match_info = match_info
|
||||
match_info.current_app = request.app
|
||||
return await match_info.handler(new_request)
|
||||
else:
|
||||
raise web.HTTPNotFound(text="This event has no document.")
|
||||
else:
|
||||
document_id = document.row.document_id
|
||||
# Redirect to POST /documents/{document_id}/...
|
||||
new_path = f"/documents/{document_id}".format(document_id=document_id) + tail
|
||||
logger.info(f"Redirecting {request=} to {new_path}")
|
||||
new_request = request.clone(rel_url=new_path)
|
||||
new_request['event_id'] = event_id
|
||||
new_request['document_id'] = document_id
|
||||
match_info = await request.app.router.resolve(new_request)
|
||||
new_request._match_info = match_info
|
||||
match_info.current_app = request.app
|
||||
if match_info.handler:
|
||||
return await match_info.handler(new_request)
|
||||
else:
|
||||
logger.info(f"Could not find handler matching {new_request}")
|
||||
raise web.HTTPNotFound()
|
||||
|
||||
29
src/routes/lib.py
Normal file
29
src/routes/lib.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import NamedTuple, Any, Optional, Self, Unpack, List, TypedDict
|
||||
from aiohttp import web, ClientSession
|
||||
from discord import Webhook
|
||||
|
||||
from data.database import Database
|
||||
from datamodels import DataModel
|
||||
from modules.profiles.data import ProfileData
|
||||
from meta import conf
|
||||
|
||||
dbvar = web.AppKey("database", Database)
|
||||
datamodelsv = web.AppKey("datamodels", DataModel)
|
||||
profiledatav = web.AppKey("profiledata", ProfileData)
|
||||
|
||||
|
||||
class ModelField(NamedTuple):
|
||||
name: str
|
||||
typ: Any
|
||||
required: bool
|
||||
can_create: bool
|
||||
can_edit: bool
|
||||
|
||||
|
||||
async def event_log(*args, **kwargs):
|
||||
# Post the given message to the configured event log, if set
|
||||
event_log_url = conf.api.get('EVENTLOG')
|
||||
if event_log_url:
|
||||
async with ClientSession() as session:
|
||||
webhook = Webhook.from_url(event_log_url, session=session)
|
||||
await webhook.send(**kwargs)
|
||||
304
src/routes/specimens.py
Normal file
304
src/routes/specimens.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
- `/specimens` with `GET` and `POST`
|
||||
- `/specimens/{specimen_id}` with `PATCH` and `DELETE`
|
||||
- `/specimens/{specimen_id}/owner` which is passed to `/users/{user_id}`
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List, TYPE_CHECKING
|
||||
from aiohttp import web
|
||||
from data import Condition, condition
|
||||
from data.conditions import NULL
|
||||
from data.queries import JOINTYPE
|
||||
from datamodels import DataModel
|
||||
|
||||
from .lib import ModelField, datamodelsv, dbvar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .users import UserCreateParams, UserPayload, User
|
||||
|
||||
routes = web.RouteTableDef()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SpecimenPayload(TypedDict):
|
||||
specimen_id: int
|
||||
owner_id: int
|
||||
owner: 'UserPayload'
|
||||
born_at: str
|
||||
forgotten_at: Optional[str]
|
||||
|
||||
|
||||
class SpecimenCreateParamsReq(TypedDict, total=True):
|
||||
owner_id: int
|
||||
|
||||
|
||||
class SpecimenCreateParams(SpecimenCreateParamsReq, total=False):
|
||||
owner: 'UserCreateParams'
|
||||
born_at: str
|
||||
forgotten_at: str
|
||||
|
||||
|
||||
class SpecimenEditParams(TypedDict, total=False):
|
||||
owner_id: int
|
||||
forgotten_at: Optional[str]
|
||||
|
||||
fields = [
|
||||
ModelField('specimen_id', int, False, False, False),
|
||||
ModelField('owner_id', int, False, True, True),
|
||||
ModelField('owner', 'UserPayload', False, True, False),
|
||||
ModelField('born_at', str, False, True, False),
|
||||
ModelField('forgotten_at', str, False, True, True),
|
||||
]
|
||||
req_fields = {field.name for field in fields if field.required}
|
||||
edit_fields = {field.name for field in fields if field.can_edit}
|
||||
create_fields = {field.name for field in fields if field.can_create}
|
||||
|
||||
|
||||
class Specimen:
|
||||
def __init__(self, app: web.Application, row: DataModel.Specimen):
|
||||
self.app = app
|
||||
self.data = app[datamodelsv]
|
||||
self.row = row
|
||||
|
||||
@classmethod
|
||||
async def validate_create_params(cls, params):
|
||||
if 'owner_id' not in params and 'owner' not in params:
|
||||
raise web.HTTPBadRequest(text="One of 'owner' or 'owner_id' must be supplied to create Specimen.")
|
||||
if extra := next((key for key in params if key not in create_fields), None):
|
||||
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to specimen creation.")
|
||||
if missing := next((key for key in req_fields if key not in params), None):
|
||||
raise web.HTTPBadRequest(text=f"Specimen params missing required key '{missing}'")
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_id(cls, app: web.Application, spec_id: int) -> Optional[Self]:
|
||||
data = app[datamodelsv]
|
||||
row = await data.Specimen.fetch(int(spec_id))
|
||||
return cls(app, row) if row is not None else None
|
||||
|
||||
@classmethod
|
||||
async def query(
|
||||
cls,
|
||||
app: web.Application,
|
||||
specimen_id: Optional[str] = None,
|
||||
owner_id: Optional[str] = None,
|
||||
born_after: Optional[str] = None,
|
||||
born_before: Optional[str] = None,
|
||||
forgotten: Optional[str] = None,
|
||||
forgotten_after: Optional[str] = None,
|
||||
forgotten_before: Optional[str] = None,
|
||||
) -> List[Self]:
|
||||
data = app[datamodelsv]
|
||||
Spec = data.Specimen
|
||||
|
||||
conds = []
|
||||
|
||||
if specimen_id is not None:
|
||||
conds.append(Spec.specimen_id == int(specimen_id))
|
||||
if owner_id is not None:
|
||||
conds.append(Spec.owner_id == int(owner_id))
|
||||
if born_after is not None:
|
||||
bafter = datetime.fromisoformat(born_after)
|
||||
conds.append(Spec.born_at >= bafter)
|
||||
if born_before is not None:
|
||||
bbefore = datetime.fromisoformat(born_before)
|
||||
conds.append(Spec.born_at <= bbefore)
|
||||
if forgotten_after is not None:
|
||||
fafter = datetime.fromisoformat(forgotten_after)
|
||||
conds.append(Spec.forgotten_at >= fafter)
|
||||
if forgotten_before is not None:
|
||||
fbefore = datetime.fromisoformat(forgotten_before)
|
||||
conds.append(Spec.forgotten_at <= fbefore)
|
||||
if forgotten is not None:
|
||||
if forgotten.lower() in ('1', 'true'):
|
||||
conds.append(Spec.forgotten_at != NULL)
|
||||
elif forgotten.lower() in ('0', 'false'):
|
||||
conds.append(Spec.forgotten_at == NULL)
|
||||
rows = await Spec.fetch_where(*conds).order_by(Spec.born_at)
|
||||
return [cls(app, row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
async def create(cls, app: web.Application, **kwargs: Unpack[SpecimenCreateParams]) -> Self:
|
||||
"""
|
||||
Create a new specimen from the given data.
|
||||
|
||||
This will create the provided 'owner' if required.
|
||||
"""
|
||||
from .users import User
|
||||
|
||||
create_args = {}
|
||||
|
||||
if 'owner' in kwargs:
|
||||
# Create owner and set owner_id
|
||||
owner_args = kwargs['owner']
|
||||
await User.validate_create_params(owner_args)
|
||||
owner = await User.create(app, **owner_args)
|
||||
owner_id = owner.row.user_id
|
||||
|
||||
if 'owner_id' in kwargs and not kwargs['owner_id'] == owner_id:
|
||||
raise web.HTTPBadRequest(text="Provided `owner_id` does not match provided `owner`.")
|
||||
else:
|
||||
owner_id = int(kwargs['owner_id'])
|
||||
create_args['owner_id'] = owner_id
|
||||
|
||||
if 'born_at' in kwargs:
|
||||
create_args['born_at'] = datetime.fromisoformat(kwargs['born_at'])
|
||||
if 'forgotten_at' in kwargs:
|
||||
create_args['forgotten_at'] = datetime.fromisoformat(kwargs['forgotten_at'])
|
||||
|
||||
data = app[datamodelsv]
|
||||
|
||||
logger.info(f"Creating Specimen with {create_args=}")
|
||||
|
||||
row = await data.Specimen.create(**create_args)
|
||||
return cls(app, row)
|
||||
|
||||
async def edit(self, **kwargs: Unpack[SpecimenEditParams]):
|
||||
row = self.row
|
||||
|
||||
edit_args = {}
|
||||
if 'owner_id' in kwargs:
|
||||
edit_args['owner_id'] = kwargs['owner_id']
|
||||
# TODO: We should probably check that the specified owner exists
|
||||
if 'forgotten_at' in kwargs:
|
||||
forg = kwargs['forgotten_at']
|
||||
if forg is None:
|
||||
# Allows unsetting the forgotten date
|
||||
# This may error if the user already had a live specimen
|
||||
edit_args['forgotten_at'] = None
|
||||
else:
|
||||
edit_args['forgotten_at'] = datetime.fromisoformat(forg)
|
||||
|
||||
if edit_args:
|
||||
logger.info(f"Updating specimen {row=} with {kwargs}")
|
||||
await row.update(**edit_args)
|
||||
|
||||
async def delete(self) -> SpecimenPayload:
|
||||
payload = await self.prepare()
|
||||
await self.row.delete()
|
||||
return payload
|
||||
|
||||
async def get_owner(self):
|
||||
from .users import User
|
||||
return await User.fetch_from_id(self.app, self.row.owner_id)
|
||||
|
||||
async def prepare(self) -> SpecimenPayload:
|
||||
owner = await self.get_owner()
|
||||
if owner is None:
|
||||
raise ValueError("Specimen Owner does not exist! This should never happen!")
|
||||
|
||||
results: SpecimenPayload = {
|
||||
'specimen_id': self.row.specimen_id,
|
||||
'owner_id': self.row.owner_id,
|
||||
'owner': await owner.prepare(),
|
||||
'born_at': self.row.born_at.isoformat(),
|
||||
'forgotten_at': self.row.forgotten_at.isoformat() if self.row.forgotten_at else None
|
||||
}
|
||||
return results
|
||||
|
||||
|
||||
@routes.view('/specimens')
|
||||
@routes.view('/specimens/', name='specimens')
|
||||
class SpecimensView(web.View):
|
||||
async def post(self):
|
||||
request = self.request
|
||||
|
||||
params = await request.json()
|
||||
for key in create_fields:
|
||||
if key in request:
|
||||
params.setdefault(key, request[key])
|
||||
|
||||
await Specimen.validate_create_params(params)
|
||||
logger.info(f"Creating a new Specimen with args: {params=}")
|
||||
spec = await Specimen.create(self.request.app, **params)
|
||||
logger.debug(f"Created specimen: {spec!r}")
|
||||
|
||||
payload = await spec.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def get(self):
|
||||
request = self.request
|
||||
filter_params = {}
|
||||
keys = [
|
||||
'specimen_id', 'owner_id', 'born_after', 'born_before',
|
||||
'forgotten', 'forgotten_after', 'forgotten_before'
|
||||
]
|
||||
for key in keys:
|
||||
value = request.query.get(key, request.get(key, None))
|
||||
filter_params[key] = value
|
||||
|
||||
logger.info(f"Querying specimens with params: {filter_params=}")
|
||||
specs = await Specimen.query(request.app, **filter_params)
|
||||
payload = [await spec.prepare() for spec in specs]
|
||||
return web.json_response(payload)
|
||||
|
||||
|
||||
@routes.view('/specimens/{specimen_id}')
|
||||
@routes.view('/specimens/{specimen_id}/', name='specimen')
|
||||
class SpecimenView(web.View):
|
||||
async def resolve_specimen(self):
|
||||
request = self.request
|
||||
spec_id = request.match_info['specimen_id']
|
||||
spec = await Specimen.fetch_from_id(request.app, int(spec_id))
|
||||
|
||||
if spec is None:
|
||||
raise web.HTTPNotFound(text="No specimen exists with the given ID.")
|
||||
|
||||
return spec
|
||||
|
||||
async def get(self):
|
||||
spec = await self.resolve_specimen()
|
||||
logger.info(f"Received GET for specimen {spec=}")
|
||||
payload = await spec.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def patch(self):
|
||||
spec = await self.resolve_specimen()
|
||||
params = await self.request.json()
|
||||
|
||||
edit_data = {}
|
||||
for key, value in params.items():
|
||||
if key not in edit_fields:
|
||||
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of Specimen!")
|
||||
edit_data[key] = value
|
||||
|
||||
for key in edit_fields:
|
||||
if key in self.request:
|
||||
edit_data.setdefault(key, self.request[key])
|
||||
|
||||
logger.info(f"Received PATCH for specimen {spec} with params: {params}")
|
||||
await spec.edit(**edit_data)
|
||||
payload = await spec.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def delete(self):
|
||||
spec = await self.resolve_specimen()
|
||||
logger.info(f"Received DELETE for specimen {spec}")
|
||||
payload = await spec.delete()
|
||||
return web.json_response(payload)
|
||||
|
||||
|
||||
|
||||
@routes.route('*', "/specimens/{specimen_id}/owner")
|
||||
@routes.route('*', "/specimens/{specimen_id}/owner{tail:/.*}")
|
||||
async def specimen_owner_route(request: web.Request):
|
||||
spec_id = int(request.match_info['specimen_id'])
|
||||
spec = await Specimen.fetch_from_id(request.app, spec_id)
|
||||
if spec is None:
|
||||
raise web.HTTPNotFound(text="No specimen exists with the given ID.")
|
||||
|
||||
tail = request.match_info.get('tail', '')
|
||||
new_path = "/users/{user_id}".format(user_id=spec.row.owner_id) + tail
|
||||
|
||||
logger.info(f"Redirecting {request=} to {new_path}")
|
||||
new_request = request.clone(rel_url=new_path)
|
||||
new_request['user_id'] = spec.row.owner_id
|
||||
match_info = await request.app.router.resolve(new_request)
|
||||
new_request._match_info = match_info
|
||||
match_info.current_app = request.app
|
||||
|
||||
if match_info.handler:
|
||||
return await match_info.handler(new_request)
|
||||
else:
|
||||
logger.info(f"Could not find handler matching {new_request}")
|
||||
raise web.HTTPNotFound()
|
||||
262
src/routes/stamps.py
Normal file
262
src/routes/stamps.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from typing import Any, NamedTuple, Optional, TypedDict, Unpack, reveal_type
|
||||
from aiohttp import web
|
||||
from data.database import Database
|
||||
from datamodels import DataModel
|
||||
|
||||
from .lib import datamodelsv, ModelField
|
||||
|
||||
routes = web.RouteTableDef()
|
||||
|
||||
|
||||
class StampPayload(TypedDict):
|
||||
stamp_id: int
|
||||
document_id: int
|
||||
stamp_type: str
|
||||
pos_x: int
|
||||
pos_y: int
|
||||
rotation: float
|
||||
|
||||
|
||||
class StampCreateParamsReq(TypedDict, total=True):
|
||||
document_id: int
|
||||
stamp_type: str
|
||||
pos_x: int
|
||||
pos_y: int
|
||||
rotation: float
|
||||
|
||||
|
||||
class StampCreateParams(StampCreateParamsReq, total=False):
|
||||
pass
|
||||
|
||||
|
||||
class StampEditParams(TypedDict, total=False):
|
||||
document_id: int
|
||||
stamp_type: str
|
||||
pos_x: int
|
||||
pos_y: int
|
||||
rotation: float
|
||||
|
||||
|
||||
fields = [
|
||||
ModelField('stamp_id', int, False, False, False),
|
||||
ModelField('document_id', int, True, True, True),
|
||||
ModelField('stamp_type', int, True, True, True),
|
||||
ModelField('pos_x', int, True, True, True),
|
||||
ModelField('pos_y', int, True, True, True),
|
||||
ModelField('rotation', float, True, True, True),
|
||||
]
|
||||
req_fields = {field.name for field in fields if field.required}
|
||||
edit_fields = {field.name for field in fields if field.can_edit}
|
||||
create_fields = {field.name for field in fields if field.can_create}
|
||||
|
||||
|
||||
|
||||
class Stamp:
|
||||
def __init__(self, app: web.Application, row: DataModel.DocumentStamp):
|
||||
self.app = app
|
||||
self.data = app[datamodelsv]
|
||||
self.row = row
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_id(cls, app: web.Application, stamp_id: int) -> Optional['Stamp']:
|
||||
stamp = await app[datamodelsv].DocumentStamp.fetch(stamp_id)
|
||||
if stamp is None:
|
||||
return None
|
||||
return cls(app, stamp)
|
||||
|
||||
@classmethod
|
||||
async def query(
|
||||
cls,
|
||||
app: web.Application,
|
||||
stamp_id: Optional[int] = None,
|
||||
document_id: Optional[int] = None,
|
||||
stamp_type: Optional[str] = None,
|
||||
):
|
||||
data = app[datamodelsv]
|
||||
|
||||
query_args = {}
|
||||
if stamp_id is not None:
|
||||
query_args['stamp_id'] = int(stamp_id)
|
||||
if document_id is not None:
|
||||
query_args['document_id'] = int(document_id)
|
||||
if stamp_type is not None:
|
||||
typerows = await data.StampType.table.fetch_rows_where(stamp_type_name=stamp_type)
|
||||
typeids = [row.stamp_type_id for row in typerows]
|
||||
if not typeids:
|
||||
return []
|
||||
query_args['stamp_type'] = typeids
|
||||
results = await data.DocumentStamp.table.fetch_rows_where(**query_args)
|
||||
return [cls(app, row) for row in sorted(results, key=lambda row:row.stamp_id)]
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
app: web.Application,
|
||||
**kwargs: Unpack[StampCreateParams]
|
||||
):
|
||||
data = app[datamodelsv]
|
||||
stamp_type = kwargs['stamp_type']
|
||||
# Get the stamp_type
|
||||
rows = await data.StampType.table.fetch_rows_where(stamp_type_name=stamp_type)
|
||||
if not rows:
|
||||
# Create the stamp type
|
||||
row = await data.StampType.create(stamp_type_name=stamp_type)
|
||||
else:
|
||||
row = rows[0]
|
||||
|
||||
stamprow = await data.DocumentStamp.create(
|
||||
document_id=kwargs['document_id'],
|
||||
stamp_type=row.stamp_type_id,
|
||||
position_x=int(kwargs['pos_x']),
|
||||
position_y=int(kwargs['pos_y']),
|
||||
rotation=float(kwargs['rotation'])
|
||||
)
|
||||
return cls(app, stamprow)
|
||||
|
||||
async def prepare(self) -> StampPayload:
|
||||
typerow = await self.data.StampType.fetch(self.row.stamp_type)
|
||||
assert typerow is not None
|
||||
|
||||
results: StampPayload = {
|
||||
'stamp_id': self.row.stamp_id,
|
||||
'document_id': self.row.document_id,
|
||||
'stamp_type': typerow.stamp_type_name,
|
||||
'pos_x': self.row.position_x,
|
||||
'pos_y': self.row.position_y,
|
||||
'rotation': self.row.rotation,
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
async def edit(
|
||||
self,
|
||||
**kwargs: Unpack[StampEditParams]
|
||||
):
|
||||
data = self.data
|
||||
row = self.row
|
||||
edit_args = {}
|
||||
if stamp_type := kwargs.get('stamp_type'):
|
||||
# Get the stamp_type
|
||||
rows = await data.StampType.table.fetch_rows_where(stamp_type_name=stamp_type)
|
||||
if not rows:
|
||||
# Create the stamp type
|
||||
row = await data.StampType.create(stamp_type_name=stamp_type)
|
||||
else:
|
||||
row = rows[0]
|
||||
edit_args['stamp_type'] = row.stamp_type_id
|
||||
simple_keys = {
|
||||
'document_id': 'document_id',
|
||||
'pos_x': 'position_x',
|
||||
'pos_y': 'position_y',
|
||||
'rotation': 'rotation'
|
||||
}
|
||||
for editkey, datakey in simple_keys.items():
|
||||
if editkey in kwargs:
|
||||
edit_args[datakey] = kwargs[editkey]
|
||||
|
||||
await self.row.update(
|
||||
**edit_args
|
||||
)
|
||||
|
||||
async def delete(self) -> StampPayload:
|
||||
payload = await self.prepare()
|
||||
await self.row.delete()
|
||||
return payload
|
||||
|
||||
|
||||
@routes.view('/stamps')
|
||||
@routes.view('/stamps/', name='stamps')
|
||||
class StampsView(web.View):
|
||||
async def get(self):
|
||||
request = self.request
|
||||
# Decode request parameters to filter args
|
||||
filter_params = {}
|
||||
|
||||
keys = ['stamp_id', 'document_id', 'stamp_type']
|
||||
for key in keys:
|
||||
if key in request.query:
|
||||
filter_params[key] = request.query[key]
|
||||
elif key in request:
|
||||
filter_params[key] = request[key]
|
||||
|
||||
stamps = await Stamp.query(request.app, **filter_params)
|
||||
payload = [await stamp.prepare() for stamp in stamps]
|
||||
|
||||
return web.json_response(payload)
|
||||
|
||||
async def create_one(self, params: StampCreateParams):
|
||||
if extra := next((key for key in params if key not in create_fields), None):
|
||||
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to stamp creation.")
|
||||
if missing := next((key for key in req_fields if key not in params), None):
|
||||
raise web.HTTPBadRequest(text=f"Stamp params missing required key '{missing}'.")
|
||||
|
||||
# This still doesn't guarantee that the values are of the correct type, but good enough.
|
||||
stamp = await Stamp.create(self.request.app, **params)
|
||||
return stamp
|
||||
|
||||
async def post(self):
|
||||
request = self.request
|
||||
|
||||
params = await request.json()
|
||||
for key in create_fields:
|
||||
if key in request:
|
||||
params.setdefault(key, request[key])
|
||||
|
||||
stamp = await self.create_one(params)
|
||||
payload = await stamp.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def put(self):
|
||||
request = self.request
|
||||
|
||||
from_request = {key: request[key] for key in create_fields if key in request}
|
||||
argslist = await request.json()
|
||||
|
||||
payloads = []
|
||||
for args in argslist:
|
||||
stamp = await self.create_one(from_request | args)
|
||||
payload = await stamp.prepare()
|
||||
payloads.append(payload)
|
||||
|
||||
return web.json_response(payloads)
|
||||
|
||||
|
||||
@routes.view('/stamps/{stamp_id}')
|
||||
@routes.view('/stamps/{stamp_id}/', name='stamp')
|
||||
class StampView(web.View):
|
||||
|
||||
async def resolve_stamp(self):
|
||||
request = self.request
|
||||
stamp_id = request.match_info['stamp_id']
|
||||
stamp = await Stamp.fetch_from_id(request.app, int(stamp_id))
|
||||
if stamp is None:
|
||||
raise web.HTTPNotFound(text="No stamp exists with the given ID.")
|
||||
return stamp
|
||||
|
||||
async def get(self):
|
||||
stamp = await self.resolve_stamp()
|
||||
payload = await stamp.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def patch(self):
|
||||
stamp = await self.resolve_stamp()
|
||||
params = await self.request.json()
|
||||
|
||||
edit_data = {}
|
||||
for key, value in params.items():
|
||||
if key not in edit_fields:
|
||||
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of Stamp!")
|
||||
edit_data[key] = value
|
||||
|
||||
for key in edit_fields:
|
||||
if key in self.request:
|
||||
edit_data.setdefault(key, self.request[key])
|
||||
|
||||
await stamp.edit(**edit_data)
|
||||
payload = await stamp.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def delete(self):
|
||||
stamp = await self.resolve_stamp()
|
||||
payload = await stamp.delete()
|
||||
return web.json_response(payload)
|
||||
223
src/routes/transactions.py
Normal file
223
src/routes/transactions.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List, TYPE_CHECKING
|
||||
from aiohttp import web
|
||||
from data import Condition, condition
|
||||
from data.conditions import NULL
|
||||
from data.queries import JOINTYPE
|
||||
from datamodels import DataModel
|
||||
|
||||
from .lib import ModelField, datamodelsv, dbvar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .users import UserCreateParams, UserPayload, User
|
||||
|
||||
routes = web.RouteTableDef()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransactionPayload(TypedDict):
|
||||
transaction_id: int
|
||||
user_id: int
|
||||
user: 'UserPayload'
|
||||
amount: int
|
||||
description: str
|
||||
reference: Optional[str]
|
||||
created_at: str
|
||||
|
||||
class TransactionCreateParamsReq(TypedDict, total=True):
|
||||
user_id: int
|
||||
amount: int
|
||||
description: str
|
||||
|
||||
class TransactionCreateParams(TransactionCreateParamsReq, total=False):
|
||||
reference: str
|
||||
|
||||
fields = [
|
||||
ModelField('transaction_id', int, False, False, False),
|
||||
ModelField('user_id', int, True, True, False),
|
||||
ModelField('amount', int, True, True, False),
|
||||
ModelField('description', str, True, True, False),
|
||||
ModelField('reference', str, False, True, False),
|
||||
ModelField('created_at', str, False, False, False),
|
||||
]
|
||||
req_fields = {field.name for field in fields if field.required}
|
||||
edit_fields = {field.name for field in fields if field.can_edit}
|
||||
create_fields = {field.name for field in fields if field.can_create}
|
||||
|
||||
|
||||
class Transaction:
|
||||
def __init__(self, app: web.Application, row: DataModel.Transaction):
|
||||
self.app = app
|
||||
self.data = app[datamodelsv]
|
||||
self.row = row
|
||||
|
||||
@classmethod
|
||||
async def validate_create_params(cls, params):
|
||||
if extra := next((key for key in params if key not in create_fields), None):
|
||||
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to transaction creation.")
|
||||
if missing := next((key for key in req_fields if key not in params), None):
|
||||
raise web.HTTPBadRequest(text=f"Transaction params missing required key '{missing}'")
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_id(cls, app: web.Application, tid: int) -> Optional[Self]:
|
||||
data = app[datamodelsv]
|
||||
row = await data.Transaction.fetch(int(tid))
|
||||
return cls(app, row) if row is not None else None
|
||||
|
||||
@classmethod
|
||||
async def query(
|
||||
cls,
|
||||
app: web.Application,
|
||||
transaction_id: Optional[int] = None,
|
||||
user_id: Optional[int] = None,
|
||||
reference: Optional[str] = None,
|
||||
created_before: Optional[str] = None,
|
||||
created_after: Optional[str] = None,
|
||||
) -> List[Self]:
|
||||
data = app[datamodelsv]
|
||||
TXN = data.Transaction
|
||||
|
||||
conds = []
|
||||
|
||||
if transaction_id is not None:
|
||||
conds.append(TXN.transaction_id == int(transaction_id))
|
||||
if user_id is not None:
|
||||
conds.append(TXN.user_id == int(user_id))
|
||||
if reference is not None:
|
||||
conds.append(TXN.reference == reference)
|
||||
if created_before is not None:
|
||||
cbefore = datetime.fromisoformat(created_before)
|
||||
conds.append(TXN.created_at <= cbefore)
|
||||
if created_after is not None:
|
||||
cafter = datetime.fromisoformat(created_after)
|
||||
conds.append(TXN.created_at >= cafter)
|
||||
|
||||
rows = await TXN.fetch_where(*conds).order_by(TXN.created_at)
|
||||
return [cls(app, row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
async def create(cls, app: web.Application, **kwargs: Unpack[TransactionCreateParams]) -> Self:
|
||||
data = app[datamodelsv]
|
||||
|
||||
create_args = {}
|
||||
|
||||
for key in ('user_id', 'description', 'amount', 'reference'):
|
||||
create_args[key] = kwargs.get(key)
|
||||
|
||||
logger.info(f"Creating Transaction with {create_args=}")
|
||||
row = await data.Transaction.create(**create_args)
|
||||
return cls(app, row)
|
||||
|
||||
async def edit(self):
|
||||
raise ValueError("Transactions are immutable.")
|
||||
|
||||
async def delete(self):
|
||||
raise ValueError("Transactions cannot be deleted directly.")
|
||||
|
||||
async def get_user(self):
|
||||
from .users import User
|
||||
return await User.fetch_from_id(self.app, self.row.user_id)
|
||||
|
||||
async def prepare(self) -> TransactionPayload:
|
||||
user = await self.get_user()
|
||||
if user is None:
|
||||
raise ValueError("Transaction owner does not exist! This cannot happen.")
|
||||
|
||||
results: TransactionPayload = {
|
||||
'transaction_id': self.row.transaction_id,
|
||||
'user_id': self.row.user_id,
|
||||
'user': await user.prepare(details=False),
|
||||
'amount': self.row.amount,
|
||||
'description': self.row.description,
|
||||
'reference': self.row.reference,
|
||||
'created_at': self.row.created_at.isoformat(),
|
||||
}
|
||||
return results
|
||||
|
||||
|
||||
@routes.view('/transactions')
|
||||
@routes.view('/transactions/', name='transactions')
|
||||
class TransactionsView(web.View):
|
||||
async def post(self):
|
||||
request = self.request
|
||||
|
||||
params = await request.json()
|
||||
for key in create_fields:
|
||||
if key in request:
|
||||
params.setdefault(key, request[key])
|
||||
|
||||
await Transaction.validate_create_params(params)
|
||||
logger.info(f"Creating a new Transaction with args: {params=}")
|
||||
txn = await Transaction.create(self.request.app, **params)
|
||||
logger.debug(f"Created transaction: {txn!r}")
|
||||
|
||||
payload = await txn.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def get(self):
|
||||
request = self.request
|
||||
filter_params = {}
|
||||
keys = [
|
||||
'transaction_id', 'user_id', 'reference',
|
||||
'created_before', 'created_after',
|
||||
]
|
||||
for key in keys:
|
||||
value = request.query.get(key, request.get(key, None))
|
||||
filter_params[key] = value
|
||||
|
||||
logger.info(f"Querying transactions with params: {filter_params=}")
|
||||
txns = await Transaction.query(request.app, **filter_params)
|
||||
payload = [await txn.prepare() for txn in txns]
|
||||
return web.json_response(payload)
|
||||
|
||||
|
||||
@routes.view('/transactions/{transaction_id}')
|
||||
@routes.view('/transactions/{transaction_id}/', name='transaction_id')
|
||||
class TransactionView(web.View):
|
||||
async def resolve_transaction(self):
|
||||
request = self.request
|
||||
txn_id = request.match_info['transaction_id']
|
||||
txn = await Transaction.fetch_from_id(request.app, int(txn_id))
|
||||
|
||||
if txn is None:
|
||||
raise web.HTTPNotFound(text="No transaction exists with the given ID.")
|
||||
|
||||
return txn
|
||||
|
||||
async def get(self):
|
||||
txn = await self.resolve_transaction()
|
||||
logger.info(f"Received GET for transaction {txn=}")
|
||||
payload = await txn.prepare()
|
||||
return web.json_response(payload)
|
||||
|
||||
async def patch(self):
|
||||
raise web.HTTPBadRequest(text="Transactions are immutable and cannot be edited.")
|
||||
|
||||
async def delete(self):
|
||||
raise web.HTTPBadRequest(text="Transactions cannot be individually deleted.")
|
||||
|
||||
|
||||
@routes.route('*', "/transactions/{transaction_id}/user")
|
||||
@routes.route('*', "/transactions/{transaction_id}/user{tail:/.*}")
|
||||
async def transaction_user_route(request: web.Request):
|
||||
txn_id = int(request.match_info['transaction_id'])
|
||||
txn = await Transaction.fetch_from_id(request.app, txn_id)
|
||||
if txn is None:
|
||||
raise web.HTTPNotFound(text="No transaction exists with the given ID.")
|
||||
|
||||
tail = request.match_info.get('tail', '')
|
||||
new_path = "/users/{user_id}".format(user_id=txn.row.user_id) + tail
|
||||
|
||||
logger.info(f"Redirecting {request=} to {new_path}")
|
||||
new_request = request.clone(rel_url=new_path)
|
||||
new_request['user_id'] = txn.row.user_id
|
||||
match_info = await request.app.router.resolve(new_request)
|
||||
new_request._match_info = match_info
|
||||
match_info.current_app = request.app
|
||||
|
||||
if match_info.handler:
|
||||
return await match_info.handler(new_request)
|
||||
else:
|
||||
logger.info(f"Could not find handler matching {new_request}")
|
||||
raise web.HTTPNotFound()
|
||||
461
src/routes/users.py
Normal file
461
src/routes/users.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
- `/users` with `POST`, `GET`, `PATCH`, `DELETE`
|
||||
- `/users/{user_id}` with `GET`, `PATCH`, `DELETE`
|
||||
- `/users/{user_id}/events` which is passed to `/events`
|
||||
- `/users/{user_id}/specimen` which is passed to `/specimens/{specimen_id}`
|
||||
- `/users/{user_id}/specimens` which is passed to `/specimens`
|
||||
- `/users/{user_id}/wallet` with `GET`
|
||||
- `/users/{user_id}/transactions` which is passed to `/transactions`
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List
|
||||
from aiohttp import web
|
||||
import discord
|
||||
from data import Condition, condition
|
||||
from data.conditions import NULL
|
||||
from data.queries import JOINTYPE
|
||||
from datamodels import DataModel
|
||||
|
||||
from modules.profiles.data import ProfileData
|
||||
from utils.lib import MessageArgs, tabulate
|
||||
|
||||
from .lib import ModelField, datamodelsv, dbvar, event_log, profiledatav
|
||||
from .specimens import Specimen, SpecimenPayload
|
||||
|
||||
routes = web.RouteTableDef()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UserPayload(TypedDict):
|
||||
user_id: int
|
||||
twitch_id: Optional[str]
|
||||
name: Optional[str]
|
||||
preferences: Optional[str]
|
||||
created_at: str
|
||||
|
||||
|
||||
class UserDetailsPayload(UserPayload):
|
||||
specimen: Optional[SpecimenPayload]
|
||||
inventory: List # TODO
|
||||
wallet: int
|
||||
|
||||
|
||||
class UserCreateParamsReq(TypedDict, total=True):
|
||||
twitch_id: str
|
||||
name: str
|
||||
|
||||
|
||||
class UserCreateParams(UserCreateParamsReq, total=False):
|
||||
preferences: str
|
||||
|
||||
|
||||
class UserEditParams(TypedDict, total=False):
|
||||
name: Optional[str]
|
||||
preferences: Optional[str]
|
||||
|
||||
|
||||
fields = [
|
||||
ModelField('user_id', int, False, False, False),
|
||||
ModelField('twitch_id', str, True, True, False),
|
||||
ModelField('name', str, True, True, True),
|
||||
ModelField('preferences', str, False, True, True),
|
||||
ModelField('created_at', str, False, False, False),
|
||||
]
|
||||
req_fields = {field.name for field in fields if field.required}
|
||||
edit_fields = {field.name for field in fields if field.can_edit}
|
||||
create_fields = {field.name for field in fields if field.can_create}
|
||||
|
||||
|
||||
class User:
|
||||
def __init__(self, app: web.Application, row: DataModel.Dreamer):
|
||||
self.app = app
|
||||
self.data = app[datamodelsv]
|
||||
self.profile_data = app[profiledatav]
|
||||
|
||||
self.row = row
|
||||
self._pref_row: Optional[DataModel.UserPreferences] = None
|
||||
|
||||
async def get_prefs(self) -> DataModel.UserPreferences:
|
||||
if self._pref_row is None:
|
||||
self._pref_row = await self.data.UserPreferences.fetch_or_create(self.row.user_id)
|
||||
return self._pref_row
|
||||
|
||||
@classmethod
|
||||
async def validate_create_params(cls, params):
|
||||
if extra := next((key for key in params if key not in create_fields), None):
|
||||
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to user creation.")
|
||||
if missing := next((key for key in req_fields if key not in params), None):
|
||||
raise web.HTTPBadRequest(text=f"User params missing required key '{missing}'")
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_id(cls, app: web.Application, user_id: int):
|
||||
data = app[datamodelsv]
|
||||
row = await data.Dreamer.fetch(int(user_id))
|
||||
return cls(app, row) if row is not None else None
|
||||
|
||||
@classmethod
|
||||
async def query(
|
||||
cls,
|
||||
app: web.Application,
|
||||
user_id: Optional[str] = None,
|
||||
twitch_id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
created_before: Optional[str] = None,
|
||||
created_after: Optional[str] = None,
|
||||
) -> List[Self]:
|
||||
data = app[datamodelsv]
|
||||
Dreamer = data.Dreamer
|
||||
|
||||
conds = []
|
||||
if user_id is not None:
|
||||
conds.append(Dreamer.user_id == int(user_id))
|
||||
if twitch_id is not None:
|
||||
conds.append(Dreamer.twitch_id == twitch_id)
|
||||
if name is not None:
|
||||
conds.append(Dreamer.name == name)
|
||||
if created_before is not None:
|
||||
cbefore = datetime.fromisoformat(created_before)
|
||||
conds.append(Dreamer.created_at <= cbefore)
|
||||
if created_after is not None:
|
||||
cafter = datetime.fromisoformat(created_after)
|
||||
conds.append(Dreamer.created_at >= cafter)
|
||||
|
||||
rows = await Dreamer.fetch_where(*conds).order_by(Dreamer.created_at)
|
||||
return [cls(app, row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
async def create(cls, app: web.Application, **kwargs: Unpack[UserCreateParams]):
|
||||
"""
|
||||
Create a new User from the provided data.
|
||||
|
||||
This creates the associated UserProfile, TwitchProfile, and UserPreferences if needed.
|
||||
If a profile already exists, this does *not* error.
|
||||
Instead, this updates the existing User with the new data.
|
||||
"""
|
||||
data = app[datamodelsv]
|
||||
|
||||
twitch_id = kwargs['twitch_id']
|
||||
name = kwargs['name']
|
||||
prefs = kwargs.get('preferences')
|
||||
|
||||
# Quick sanity check on the twitch id
|
||||
if not twitch_id or not twitch_id.isdigit():
|
||||
raise web.HTTPBadRequest(text="Invalid 'twitch_id' passed to user creation!")
|
||||
|
||||
# First check if the profile already exists by querying the Dreamer database
|
||||
edited = 0 # 0 means not edited, 1 means created, 2 means modified
|
||||
rows = await data.Dreamer.fetch_where(twitch_id=twitch_id)
|
||||
if rows:
|
||||
logger.debug(f"Updating Dreamer for {twitch_id=} with {kwargs}")
|
||||
dreamer = rows[0]
|
||||
# A twitch profile with this twitch_id already exists
|
||||
# But it is possible UserPreferences don't exist
|
||||
if dreamer.preferences is None and dreamer.name is None:
|
||||
await data.UserPreferences.fetch_or_create(dreamer.user_id, twitch_name=name, preferences=prefs)
|
||||
dreamer = await dreamer.refresh()
|
||||
edited = 2
|
||||
|
||||
# Now compare the existing data against the provided data and update if needed
|
||||
if name != dreamer.name:
|
||||
q = data.UserPreferences.table.update_where(profileid=dreamer.user_id)
|
||||
q.set(twitch_name=name)
|
||||
if prefs is not None:
|
||||
q.set(preferences=prefs)
|
||||
await q
|
||||
dreamer = await dreamer.refresh()
|
||||
edited = 2
|
||||
else:
|
||||
# Create from scratch
|
||||
logger.info(f"Creating Dreamer for {twitch_id=} with {kwargs}")
|
||||
# TODO: Should be in a transaction.. actually let's add transactions to the middleware..
|
||||
profile_data = app[profiledatav]
|
||||
user_profile = await profile_data.UserProfileRow.create(nickname=name)
|
||||
await profile_data.TwitchProfileRow.create(
|
||||
profileid=user_profile.profileid,
|
||||
userid=twitch_id,
|
||||
)
|
||||
await data.UserPreferences.create(
|
||||
profileid=user_profile.profileid,
|
||||
twitch_name=name,
|
||||
preferences=prefs
|
||||
)
|
||||
dreamer = await data.Dreamer.fetch(user_profile.profileid)
|
||||
assert dreamer is not None
|
||||
edited = 1
|
||||
|
||||
self = cls(app, dreamer)
|
||||
if edited == 1:
|
||||
args = await self.event_log_args(title=f"User #{dreamer.user_id} created!")
|
||||
await event_log(**args.send_args)
|
||||
elif edited == 2:
|
||||
args = await self.event_log_args(title=f"User #{dreamer.user_id} updated!")
|
||||
await event_log(**args.send_args)
|
||||
return self
|
||||
|
||||
async def edit(self, **kwargs: Unpack[UserEditParams]):
|
||||
data = self.data
|
||||
# We can edit the name, and preferences
|
||||
prefs = await self.get_prefs()
|
||||
update_args = {}
|
||||
if 'name' in kwargs:
|
||||
update_args['twitch_name'] = kwargs['name']
|
||||
if 'preferences' in kwargs:
|
||||
update_args['preferences'] = kwargs['preferences']
|
||||
|
||||
if update_args:
|
||||
logger.info(f"Updating dreamer {self.row=} with {kwargs}")
|
||||
await prefs.update(**update_args)
|
||||
|
||||
args = await self.event_log_args(title=f"User #{self.row.user_id} updated!")
|
||||
await event_log(**args.send_args)
|
||||
|
||||
async def delete(self) -> UserDetailsPayload:
|
||||
payload = await self.prepare(details=True)
|
||||
# This will cascade to all other data the user has
|
||||
await self.profile_data.UserProfileRow.table.delete_where(profileid=self.row.user_id)
|
||||
# Make sure we take the user out of cache
|
||||
await self.row.refresh()
|
||||
return payload
|
||||
|
||||
async def get_wallet(self):
|
||||
query = self.data.Transaction.table.select_where(user_id=self.row.user_id)
|
||||
query.select(wallet="SUM(amount)")
|
||||
query.with_no_adapter()
|
||||
results = await query
|
||||
|
||||
return results[0]['wallet']
|
||||
|
||||
async def get_specimen(self) -> Optional[Specimen]:
|
||||
data = self.data
|
||||
active_specrows = await data.Specimen.fetch_where(
|
||||
owner_id=self.row.user_id,
|
||||
forgotten_at=NULL
|
||||
)
|
||||
if active_specrows:
|
||||
row = active_specrows[0]
|
||||
spec = Specimen(self.app, row)
|
||||
else:
|
||||
spec = None
|
||||
return spec
|
||||
|
||||
async def get_inventory(self):
|
||||
return []
|
||||
|
||||
async def event_log_args(self, **kwargs) -> MessageArgs:
|
||||
desc = '\n'.join(await self.tabulate())
|
||||
embed = discord.Embed(description=desc, timestamp=self.row.created_at, **kwargs)
|
||||
embed.set_footer(text='Created At')
|
||||
|
||||
# TODO: We could add wallet, specimen, and inventory info here too
|
||||
return MessageArgs(embed=embed)
|
||||
|
||||
async def tabulate(self):
|
||||
"""
|
||||
Present the User as a discord-readable table.
|
||||
"""
|
||||
table = {
|
||||
'user_id': f"`{self.row.user_id}`",
|
||||
'twitch_id': f"`{self.row.twitch_id}`" if self.row.twitch_id else 'No Twitch linked',
|
||||
'name': f"`{self.row.name}`",
|
||||
'preferences': f"`{self.row.preferences}`",
|
||||
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
|
||||
}
|
||||
return tabulate(*table.items())
|
||||
|
||||
|
||||
@overload
|
||||
async def prepare(self, details: Literal[True]=True) -> UserDetailsPayload:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def prepare(self, details: Literal[False]=False) -> UserPayload:
|
||||
...
|
||||
|
||||
async def prepare(self, details=False) -> UserPayload | UserDetailsPayload:
|
||||
# Since we are working with view rows, make sure we refresh
|
||||
row = self.row
|
||||
await row.refresh()
|
||||
|
||||
base_user: UserPayload = {
|
||||
'user_id': row.user_id,
|
||||
'twitch_id': str(row.twitch_id) if row.twitch_id else None,
|
||||
'name': row.name,
|
||||
'preferences': row.preferences,
|
||||
'created_at': row.created_at.isoformat(),
|
||||
}
|
||||
|
||||
if details:
|
||||
# Now add details
|
||||
specimen = await self.get_specimen()
|
||||
sp_payload = await specimen.prepare() if specimen is not None else None
|
||||
inventory = [await item.prepare() for item in await self.get_inventory()]
|
||||
user: UserPayload = base_user | {
|
||||
'specimen': sp_payload,
|
||||
'inventory': inventory,
|
||||
'wallet': await self.get_wallet(),
|
||||
}
|
||||
else:
|
||||
user = base_user
|
||||
logger.debug(f"User prepared: {user}")
|
||||
return user
|
||||
|
||||
|
||||
@routes.view('/users')
|
||||
@routes.view('/users/', name='users')
|
||||
class UsersView(web.View):
|
||||
async def post(self):
|
||||
request = self.request
|
||||
|
||||
params = await request.json()
|
||||
for key in create_fields:
|
||||
if key in request:
|
||||
params.setdefault(key, request[key])
|
||||
|
||||
await User.validate_create_params(params)
|
||||
logger.info(f"Creating a new user with args: {params=}")
|
||||
user = await User.create(self.request.app, **params)
|
||||
logger.debug(f"Created user: {user!r}")
|
||||
payload = await user.prepare(details=True)
|
||||
return web.json_response(payload)
|
||||
|
||||
async def get(self):
|
||||
request = self.request
|
||||
filter_params = {}
|
||||
keys = [
|
||||
'user_id', 'twitch_id', 'name', 'created_before', 'created_after',
|
||||
]
|
||||
for key in keys:
|
||||
value = request.query.get(key, request.get(key, None))
|
||||
filter_params[key] = value
|
||||
|
||||
logger.info(f"Querying users with params: {filter_params=}")
|
||||
users = await User.query(request.app, **filter_params)
|
||||
payload = [await user.prepare(details=True) for user in users]
|
||||
return web.json_response(payload)
|
||||
|
||||
|
||||
@routes.view('/users/{user_id}')
|
||||
@routes.view('/users/{user_id}/', name='user')
|
||||
class UserView(web.View):
|
||||
async def resolve_user(self):
|
||||
request = self.request
|
||||
user_id = request.match_info['user_id']
|
||||
user = await User.fetch_from_id(request.app, int(user_id))
|
||||
if user is None:
|
||||
raise web.HTTPNotFound(text="No user exists with the given ID.")
|
||||
return user
|
||||
|
||||
async def get(self):
|
||||
user = await self.resolve_user()
|
||||
logger.info(f"Received GET for user {user=}")
|
||||
payload = await user.prepare(details=True)
|
||||
return web.json_response(payload)
|
||||
|
||||
async def patch(self):
|
||||
user = await self.resolve_user()
|
||||
params = await self.request.json()
|
||||
|
||||
edit_data = {}
|
||||
for key, value in params.items():
|
||||
if key not in edit_fields:
|
||||
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of User!")
|
||||
edit_data[key] = value
|
||||
|
||||
for key in edit_fields:
|
||||
if key in self.request:
|
||||
edit_data.setdefault(key, self.request[key])
|
||||
|
||||
logger.info(f"Received PATCH for user {user} with params: {params}")
|
||||
await user.edit(**edit_data)
|
||||
payload = await user.prepare(details=True)
|
||||
return web.json_response(payload)
|
||||
|
||||
async def delete(self):
|
||||
user = await self.resolve_user()
|
||||
logger.info(f"Received DELETE for user {user}")
|
||||
payload = await user.delete()
|
||||
return web.json_response(payload)
|
||||
|
||||
|
||||
@routes.route('*', "/users/{user_id}{tail:/events}")
|
||||
@routes.route('*', "/users/{user_id}{tail:/events/.*}")
|
||||
@routes.route('*', "/users/{user_id}{tail:/transactions}")
|
||||
@routes.route('*', "/users/{user_id}{tail:/transactions/.*}")
|
||||
@routes.route('*', "/users/{user_id}{tail:/specimens}")
|
||||
@routes.route('*', "/users/{user_id}{tail:/specimens/.*}")
|
||||
async def user_prefix_routes(request: web.Request):
|
||||
user_id = int(request.match_info['user_id'])
|
||||
user = await User.fetch_from_id(request.app, user_id)
|
||||
if user is None:
|
||||
raise web.HTTPNotFound(text="No user exists with the given ID.")
|
||||
|
||||
new_path = request.match_info['tail']
|
||||
logger.info(f"Redirecting {request=} to {new_path=} and setting {user_id=}")
|
||||
|
||||
new_request = request.clone(rel_url=new_path)
|
||||
new_request['user_id'] = user_id
|
||||
match_info = await request.app.router.resolve(new_request)
|
||||
new_request._match_info = match_info
|
||||
match_info.current_app = request.app
|
||||
if match_info.handler:
|
||||
return await match_info.handler(new_request)
|
||||
else:
|
||||
logger.info(f"Could not find handler matching {new_request}")
|
||||
raise web.HTTPNotFound()
|
||||
|
||||
|
||||
@routes.route('*', "/users/{user_id}/specimen")
|
||||
@routes.route('*', "/users/{user_id}/specimen{tail:/.*}")
|
||||
async def user_specimen_route(request: web.Request):
|
||||
user_id = int(request.match_info['user_id'])
|
||||
user = await User.fetch_from_id(request.app, user_id)
|
||||
if user is None:
|
||||
raise web.HTTPNotFound(text="No user exists with the given ID.")
|
||||
tail = request.match_info.get('tail', '')
|
||||
|
||||
specimen = await user.get_specimen()
|
||||
if request.method == 'POST' and not tail.strip('/'):
|
||||
if specimen is None:
|
||||
# Redirect to POST /specimens
|
||||
# TODO: Would be nicer to use named handler here
|
||||
new_path = '/specimens'
|
||||
logger.info(f"Redirecting {request=} to POST /specimens")
|
||||
new_request = request.clone(rel_url=new_path)
|
||||
new_request['user_id'] = user_id
|
||||
new_request['owner_id'] = user_id
|
||||
match_info = await request.app.router.resolve(new_request)
|
||||
new_request._match_info = match_info
|
||||
match_info.current_app = request.app
|
||||
return await match_info.handler(new_request)
|
||||
else:
|
||||
raise web.HTTPBadRequest(text="This user already has an active specimen!")
|
||||
elif specimen is None:
|
||||
raise web.HTTPNotFound(text="This user has no active specimen.")
|
||||
else:
|
||||
specimen_id = specimen.row.specimen_id
|
||||
# Redirect to POST /specimens/{specimen_id}/...
|
||||
new_path = f"/specimens/{specimen_id}".format(specimen_id=specimen_id) + tail
|
||||
logger.info(f"Redirecting {request=} to {new_path}")
|
||||
new_request = request.clone(rel_url=new_path)
|
||||
new_request['user_id'] = user_id
|
||||
new_request['owner_id'] = user_id
|
||||
new_request['specimen_id'] = specimen_id
|
||||
match_info = await request.app.router.resolve(new_request)
|
||||
new_request._match_info = match_info
|
||||
match_info.current_app = request.app
|
||||
if match_info.handler:
|
||||
return await match_info.handler(new_request)
|
||||
else:
|
||||
logger.info(f"Could not find handler matching {new_request}")
|
||||
raise web.HTTPNotFound()
|
||||
|
||||
|
||||
@routes.route('GET', "/users/{user_id}/wallet")
|
||||
@routes.route('GET', "/users/{user_id}/wallet/")
|
||||
async def user_wallet_route(request: web.Request):
|
||||
user_id = int(request.match_info['user_id'])
|
||||
user = await User.fetch_from_id(request.app, user_id)
|
||||
if user is None:
|
||||
raise web.HTTPNotFound(text="No user exists with the given ID.")
|
||||
wallet = await user.get_wallet()
|
||||
return web.json_response(wallet)
|
||||
9
src/twitch/__init__.py
Normal file
9
src/twitch/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import TwitchAuthCog
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(TwitchAuthCog(bot))
|
||||
|
||||
50
src/twitch/authclient.py
Normal file
50
src/twitch/authclient.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Testing client for the twitch AuthServer.
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from twitchAPI.twitch import Twitch
|
||||
from twitchAPI.oauth import UserAuthenticator
|
||||
from twitchAPI.type import AuthScope
|
||||
|
||||
from meta.config import conf
|
||||
|
||||
|
||||
URI = "http://localhost:3000/twiauth/confirm"
|
||||
TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ]
|
||||
|
||||
async def main():
|
||||
# Load in client id and secret
|
||||
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
|
||||
auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI)
|
||||
url = auth.return_auth_url()
|
||||
|
||||
# Post url to user
|
||||
print(url)
|
||||
|
||||
# Send listen request to server
|
||||
# Wait for listen request
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws:
|
||||
await ws.send_json({'state': auth.state})
|
||||
result = await ws.receive_json()
|
||||
|
||||
# Hopefully get back code, print the response
|
||||
print(f"Recieved: {result}")
|
||||
|
||||
# Authorise with code and client details
|
||||
tokens = await auth.authenticate(user_token=result['code'])
|
||||
if tokens:
|
||||
token, refresh = tokens
|
||||
await twitch.set_user_authentication(token, TARGET_SCOPE, refresh)
|
||||
print(f"Authorised!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
86
src/twitch/authserver.py
Normal file
86
src/twitch/authserver.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import logging
|
||||
import uuid
|
||||
import asyncio
|
||||
from contextvars import ContextVar
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
reqid: ContextVar[str] = ContextVar('reqid', default='ROOT')
|
||||
|
||||
|
||||
class AuthServer:
|
||||
def __init__(self):
|
||||
self.listeners = {}
|
||||
|
||||
async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse:
|
||||
args = request.query
|
||||
if 'state' not in args:
|
||||
raise web.HTTPBadRequest(text="No state provided.")
|
||||
if args['state'] not in self.listeners:
|
||||
raise web.HTTPBadRequest(text="Invalid state.")
|
||||
self.listeners[args['state']].set_result(dict(args))
|
||||
return web.Response(text="Authorisation complete! You may now close this page and return to the application.")
|
||||
|
||||
async def handle_listen_request(self, request: web.Request) -> web.StreamResponse:
|
||||
_reqid = str(uuid.uuid1())
|
||||
reqid.set(_reqid)
|
||||
|
||||
logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}")
|
||||
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
# Get the listen request data
|
||||
try:
|
||||
listen_req = await ws.receive_json(timeout=60)
|
||||
logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}")
|
||||
if 'state' not in listen_req:
|
||||
logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.")
|
||||
raise web.HTTPBadRequest(text="Listen request must include state string.")
|
||||
elif listen_req['state'] in self.listeners:
|
||||
logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.")
|
||||
raise web.HTTPBadRequest(text="Invalid state string.")
|
||||
except ValueError:
|
||||
logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.")
|
||||
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
|
||||
except TypeError:
|
||||
logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.")
|
||||
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
|
||||
except asyncio.TimeoutError:
|
||||
logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.")
|
||||
raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.")
|
||||
except Exception:
|
||||
logger.exception(f"[reqid: {_reqid}] Unknown exception.")
|
||||
raise web.HTTPInternalServerError()
|
||||
|
||||
try:
|
||||
fut = self.listeners[listen_req['state']] = asyncio.Future()
|
||||
result = await asyncio.wait_for(fut, timeout=120)
|
||||
except asyncio.TimeoutError:
|
||||
logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.")
|
||||
raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.")
|
||||
finally:
|
||||
self.listeners.pop(listen_req['state'], None)
|
||||
|
||||
logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.")
|
||||
await ws.send_json(result)
|
||||
await ws.close()
|
||||
logger.debug(f"[reqid: {_reqid}] Request completed handling.")
|
||||
|
||||
return ws
|
||||
|
||||
def main(argv):
|
||||
app = web.Application()
|
||||
server = AuthServer()
|
||||
app.router.add_get("/twiauth/confirm", server.handle_twitch_callback)
|
||||
app.router.add_get("/twiauth/listen", server.handle_listen_request)
|
||||
|
||||
logger.info("App setup and configured. Starting now.")
|
||||
web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
main(sys.argv)
|
||||
113
src/twitch/cog.py
Normal file
113
src/twitch/cog.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from datetime import timedelta
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
|
||||
from twitchAPI.oauth import UserAuthenticator
|
||||
from twitchAPI.twitch import Twitch
|
||||
from twitchAPI.type import AuthScope
|
||||
from twitchio.ext import commands
|
||||
|
||||
|
||||
from data.queries import ORDER
|
||||
from meta import LionCog, LionBot, CrocBot
|
||||
from meta.LionContext import LionContext
|
||||
from twitch.userflow import UserAuthFlow
|
||||
from utils.lib import utc_now
|
||||
from . import logger
|
||||
from .data import TwitchAuthData
|
||||
|
||||
|
||||
class TwitchAuthCog(LionCog):
|
||||
DEFAULT_SCOPES = []
|
||||
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = bot.db.load_registry(TwitchAuthData())
|
||||
|
||||
self.client_cache = {}
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
# ----- Auth API -----
|
||||
|
||||
async def fetch_client_for(self, userid: str):
|
||||
authrow = await self.data.UserAuthRow.fetch(userid)
|
||||
if authrow is None:
|
||||
# TODO: Some user authentication error
|
||||
self.client_cache.pop(userid, None)
|
||||
raise ValueError("Requested user is not authenticated.")
|
||||
if (twitch := self.client_cache.get(userid)) is None:
|
||||
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
|
||||
scopes = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||
authscopes = [AuthScope(scope) for scope in scopes]
|
||||
await twitch.set_user_authentication(authrow.access_token, authscopes, authrow.refresh_token)
|
||||
self.client_cache[userid] = twitch
|
||||
return twitch
|
||||
|
||||
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
|
||||
"""
|
||||
Checks whether the given userid is authorised.
|
||||
If 'scopes' is given, will also check the user has all of the given scopes.
|
||||
"""
|
||||
authrow = await self.data.UserAuthRow.fetch(userid)
|
||||
if authrow:
|
||||
if scopes:
|
||||
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||
desired = {scope.value for scope in scopes}
|
||||
has_auth = desired.issubset(has_scopes)
|
||||
logger.info(f"Auth check for `{userid}`: Requested scopes {desired}, has scopes {has_scopes}. Passed: {has_auth}")
|
||||
else:
|
||||
has_auth = True
|
||||
else:
|
||||
has_auth = False
|
||||
return has_auth
|
||||
|
||||
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
|
||||
"""
|
||||
Start the user authentication flow for the given userid.
|
||||
Will request the given scopes along with the default ones and any existing scopes.
|
||||
"""
|
||||
self.client_cache.pop(userid, None)
|
||||
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||
existing = map(AuthScope, existing_strs)
|
||||
to_request = set(existing).union(scopes)
|
||||
return await self.start_auth(to_request)
|
||||
|
||||
async def start_auth(self, scopes = []):
|
||||
# TODO: Work out a way to just clone the current twitch object
|
||||
# Or can we otherwise build UserAuthenticator without app auth?
|
||||
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
|
||||
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
|
||||
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
|
||||
await flow.setup()
|
||||
|
||||
return flow
|
||||
|
||||
# ----- Commands -----
|
||||
@cmds.hybrid_command(name='auth')
|
||||
async def cmd_auth(self, ctx: LionContext):
|
||||
if ctx.interaction:
|
||||
await ctx.interaction.response.defer(ephemeral=True)
|
||||
flow = await self.start_auth()
|
||||
await ctx.reply(flow.auth.return_auth_url())
|
||||
await flow.run()
|
||||
await ctx.reply("Authentication Complete!")
|
||||
|
||||
@cmds.hybrid_command(name='modauth')
|
||||
async def cmd_modauth(self, ctx: LionContext):
|
||||
if ctx.interaction:
|
||||
await ctx.interaction.response.defer(ephemeral=True)
|
||||
scopes = [
|
||||
AuthScope.MODERATOR_READ_FOLLOWERS,
|
||||
AuthScope.CHANNEL_READ_REDEMPTIONS,
|
||||
AuthScope.MODERATOR_MANAGE_CHAT_MESSAGES,
|
||||
]
|
||||
flow = await self.start_auth(scopes=scopes)
|
||||
await ctx.reply(flow.auth.return_auth_url())
|
||||
await flow.run()
|
||||
await ctx.reply("Authentication Complete!")
|
||||
79
src/twitch/data.py
Normal file
79
src/twitch/data.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import datetime as dt
|
||||
|
||||
from data import Registry, RowModel, Table
|
||||
from data.columns import Integer, String, Timestamp
|
||||
|
||||
|
||||
class TwitchAuthData(Registry):
|
||||
class UserAuthRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE twitch_user_auth(
|
||||
userid TEXT PRIMARY KEY,
|
||||
access_token TEXT NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
obtained_at TIMESTAMPTZ
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'twitch_user_auth'
|
||||
_cache_ = {}
|
||||
|
||||
userid = Integer(primary=True)
|
||||
access_token = String()
|
||||
refresh_token = String()
|
||||
expires_at = Timestamp()
|
||||
obtained_at = Timestamp()
|
||||
|
||||
@classmethod
|
||||
async def update_user_auth(
|
||||
cls, userid: str, token: str, refresh: str,
|
||||
expires_at: dt.datetime, obtained_at: dt.datetime,
|
||||
scopes: list[str]
|
||||
):
|
||||
if cls._connector is None:
|
||||
raise ValueError("Attempting to use uninitialised Registry.")
|
||||
async with cls._connector.connection() as conn:
|
||||
cls._connector.conn = conn
|
||||
async with conn.transaction():
|
||||
# Clear row for this userid
|
||||
await cls.table.delete_where(userid=userid)
|
||||
|
||||
# Insert new user row
|
||||
row = await cls.create(
|
||||
userid=userid,
|
||||
access_token=token,
|
||||
refresh_token=refresh,
|
||||
expires_at=expires_at,
|
||||
obtained_at=obtained_at
|
||||
)
|
||||
# Insert new scope rows
|
||||
if scopes:
|
||||
await TwitchAuthData.user_scopes.insert_many(
|
||||
('userid', 'scope'),
|
||||
*((userid, scope) for scope in scopes)
|
||||
)
|
||||
return row
|
||||
|
||||
@classmethod
|
||||
async def get_scopes_for(cls, userid: str) -> list[str]:
|
||||
"""
|
||||
Get a list of scopes stored for the given user.
|
||||
Will return an empty list if the user is not authenticated.
|
||||
"""
|
||||
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
|
||||
|
||||
return [row['scope'] for row in rows] if rows else []
|
||||
|
||||
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE twitch_user_scopes(
|
||||
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
scope TEXT
|
||||
);
|
||||
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||
"""
|
||||
user_scopes = Table('twitch_user_scopes')
|
||||
0
src/twitch/lib.py
Normal file
0
src/twitch/lib.py
Normal file
88
src/twitch/userflow.py
Normal file
88
src/twitch/userflow.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from typing import Optional
|
||||
import datetime as dt
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
import aiohttp
|
||||
from twitchAPI.twitch import Twitch
|
||||
from twitchAPI.oauth import UserAuthenticator, validate_token
|
||||
from twitchAPI.type import AuthType
|
||||
from twitchio.client import asyncio
|
||||
|
||||
from meta.errors import SafeCancellation
|
||||
from utils.lib import utc_now
|
||||
from .data import TwitchAuthData
|
||||
from . import logger
|
||||
|
||||
class UserAuthFlow:
|
||||
auth: UserAuthenticator
|
||||
data: TwitchAuthData
|
||||
auth_ws: str
|
||||
|
||||
def __init__(self, data, auth, auth_ws):
|
||||
self.auth = auth
|
||||
self.data = data
|
||||
self.auth_ws = auth_ws
|
||||
|
||||
self._setup_done = asyncio.Event()
|
||||
self._comm_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Establishes websocket connection to the AuthServer,
|
||||
and requests listening for the given state.
|
||||
Propagates any exceptions that occur during connection setup.
|
||||
"""
|
||||
if self._setup_done.is_set():
|
||||
raise ValueError("UserAuthFlow is already set up.")
|
||||
self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate')
|
||||
await self._setup_done.wait()
|
||||
if self._comm_task.done() and (exc := self._comm_task.exception()):
|
||||
raise exc
|
||||
|
||||
async def _communicate(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.ws_connect(self.auth_ws) as ws:
|
||||
await ws.send_json({'state': self.auth.state})
|
||||
self._setup_done.set()
|
||||
return await ws.receive_json()
|
||||
|
||||
async def run(self) -> TwitchAuthData.UserAuthRow:
|
||||
if not self._setup_done.is_set():
|
||||
raise ValueError("Cannot run UserAuthFlow before setup.")
|
||||
if self._comm_task is None:
|
||||
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
|
||||
|
||||
result = await self._comm_task
|
||||
if result.get('error', None):
|
||||
# TODO Custom auth errors
|
||||
# This is only documented to occur when the user denies the auth
|
||||
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
|
||||
|
||||
if result.get('state', None) != self.auth.state:
|
||||
# This should never happen unless the authserver has its wires crossed somehow,
|
||||
# or the connection has been tampered with.
|
||||
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
|
||||
logger.critical(
|
||||
f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG."
|
||||
)
|
||||
raise SafeCancellation(
|
||||
"Could not complete authentication! Invalid server response."
|
||||
)
|
||||
|
||||
# Now assume result has a valid code
|
||||
# Exchange code for an auth token and a refresh token
|
||||
# Ignore type here, authenticate returns None if a callback function has been given.
|
||||
token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore
|
||||
|
||||
# Fetch the associated userid and basic info
|
||||
v_result = await validate_token(token)
|
||||
userid = v_result['user_id']
|
||||
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
|
||||
|
||||
# Save auth data
|
||||
return await self.data.UserAuthRow.update_user_auth(
|
||||
userid=userid, token=token, refresh=refresh,
|
||||
expires_at=expiry, obtained_at=utc_now(),
|
||||
scopes=[scope.value for scope in self.auth.scopes]
|
||||
)
|
||||
27
src/utils/auth.py
Normal file
27
src/utils/auth.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import Awaitable, Callable
|
||||
from aiohttp import hdrs, web
|
||||
|
||||
CHALLENGE = web.Response(
|
||||
body="<b> 401 UNAUTHORIZED </b>",
|
||||
status=401,
|
||||
reason='UNAUTHORIZED',
|
||||
headers={
|
||||
hdrs.WWW_AUTHENTICATE: 'X-API-KEY',
|
||||
hdrs.CONTENT_TYPE: 'text/html; charset=utf-8',
|
||||
hdrs.CONNECTION: 'close',
|
||||
},
|
||||
)
|
||||
|
||||
def key_auth_factory(required_token: str):
|
||||
"""
|
||||
Creates an aiohttp middleware that ensures
|
||||
the `required_token` is provided in the `X-API-KEY` header.
|
||||
"""
|
||||
@web.middleware
|
||||
async def key_auth_middleware(request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]]) -> web.StreamResponse:
|
||||
auth_header = request.headers.get('X-API-KEY')
|
||||
if not auth_header or auth_header.strip() != required_token:
|
||||
return CHALLENGE
|
||||
else:
|
||||
return await handler(request)
|
||||
return key_auth_middleware
|
||||
Reference in New Issue
Block a user