Compare commits
19 Commits
5efcdd6709
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 445eccccd6 | |||
| 6ec500ec87 | |||
| 8e2bd67efc | |||
| 092e818990 | |||
| 2bf95beaae | |||
| 250b55634d | |||
| 9e5c2f5777 | |||
| e1a1f7d4fe | |||
| c3ed48e918 | |||
| 48a01a2861 | |||
| d83709d2c2 | |||
| 7f977f90e8 | |||
| 04b6dcbc3f | |||
| c07577cc0a | |||
| dc551b34a9 | |||
| 94bc8b6c21 | |||
| aba73b8bba | |||
| 77dc90cc32 | |||
| a02cc0977a |
@@ -189,9 +189,10 @@ CREATE TABLE stamp_types (
|
|||||||
|
|
||||||
CREATE TABLE documents (
|
CREATE TABLE documents (
|
||||||
document_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
document_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||||
document_data VARCHAR NOT NULL,
|
document_data TEXT NOT NULL,
|
||||||
seal INTEGER NOT NULL,
|
seal INTEGER NOT NULL,
|
||||||
metadata TEXT
|
metadata TEXT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE document_stamps (
|
CREATE TABLE document_stamps (
|
||||||
@@ -230,14 +231,14 @@ CREATE TABLE plain_events (
|
|||||||
event_id integer PRIMARY KEY,
|
event_id integer PRIMARY KEY,
|
||||||
event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'),
|
event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'),
|
||||||
message TEXT NOT NULL,
|
message TEXT NOT NULL,
|
||||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE raid_events (
|
CREATE TABLE raid_events (
|
||||||
event_id integer PRIMARY KEY,
|
event_id integer PRIMARY KEY,
|
||||||
event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'),
|
event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'),
|
||||||
visitor_count INTEGER NOT NULL,
|
visitor_count INTEGER NOT NULL,
|
||||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE cheer_events (
|
CREATE TABLE cheer_events (
|
||||||
@@ -246,7 +247,7 @@ CREATE TABLE cheer_events (
|
|||||||
amount INTEGER NOT NULL,
|
amount INTEGER NOT NULL,
|
||||||
cheer_type TEXT,
|
cheer_type TEXT,
|
||||||
message TEXT,
|
message TEXT,
|
||||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE subscriber_events (
|
CREATE TABLE subscriber_events (
|
||||||
@@ -255,7 +256,7 @@ CREATE TABLE subscriber_events (
|
|||||||
subscribed_length INTEGER NOT NULL,
|
subscribed_length INTEGER NOT NULL,
|
||||||
tier INTEGER NOT NULL,
|
tier INTEGER NOT NULL,
|
||||||
message TEXT,
|
message TEXT,
|
||||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
@@ -276,12 +277,14 @@ SELECT
|
|||||||
subscriber_events.subscribed_length AS subscriber_length,
|
subscriber_events.subscribed_length AS subscriber_length,
|
||||||
subscriber_events.tier AS subscriber_tier,
|
subscriber_events.tier AS subscriber_tier,
|
||||||
subscriber_events.message AS subscriber_message,
|
subscriber_events.message AS subscriber_message,
|
||||||
|
documents.seal AS document_seal
|
||||||
FROM
|
FROM
|
||||||
events
|
events
|
||||||
LEFT JOIN plain_events USING (event_id)
|
LEFT JOIN plain_events USING (event_id)
|
||||||
LEFT JOIN raid_events USING (event_id)
|
LEFT JOIN raid_events USING (event_id)
|
||||||
LEFT JOIN cheer_events USING (event_id)
|
LEFT JOIN cheer_events USING (event_id)
|
||||||
LEFT JOIN subscriber_events USING (event_id)
|
LEFT JOIN subscriber_events USING (event_id)
|
||||||
|
LEFT JOIN documents USING (document_id)
|
||||||
ORDER BY events.occurred_at ASC;
|
ORDER BY events.occurred_at ASC;
|
||||||
|
|
||||||
-- }}}
|
-- }}}
|
||||||
@@ -294,6 +297,7 @@ CREATE TABLE user_specimens (
|
|||||||
born_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
born_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
forgotten_at TIMESTAMPTZ
|
forgotten_at TIMESTAMPTZ
|
||||||
);
|
);
|
||||||
|
CREATE UNIQUE INDEX ON user_specimens (owner_id) WHERE forgotten_at IS NULL;
|
||||||
|
|
||||||
-- }}}
|
-- }}}
|
||||||
|
|
||||||
|
|||||||
61
src/api.py
Normal file
61
src/api.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from aiohttp import web
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from meta import conf
|
||||||
|
from data import Database
|
||||||
|
from utils.auth import key_auth_factory
|
||||||
|
from datamodels import DataModel
|
||||||
|
from constants import DATA_VERSION
|
||||||
|
|
||||||
|
from modules.profiles.data import ProfileData
|
||||||
|
from routes import dbvar, datamodelsv, profiledatav, register_routes
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def attach_db(app: web.Application):
|
||||||
|
db = Database(conf.data['args'])
|
||||||
|
async with db.open():
|
||||||
|
version = await db.version()
|
||||||
|
if version.version != DATA_VERSION:
|
||||||
|
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
|
||||||
|
logger.critical(error)
|
||||||
|
raise RuntimeError(error)
|
||||||
|
|
||||||
|
datamodel = DataModel()
|
||||||
|
db.load_registry(datamodel)
|
||||||
|
await datamodel.init()
|
||||||
|
|
||||||
|
profiledata = ProfileData()
|
||||||
|
db.load_registry(profiledata)
|
||||||
|
await profiledata.init()
|
||||||
|
|
||||||
|
app[dbvar] = db
|
||||||
|
app[datamodelsv] = datamodel
|
||||||
|
app[profiledatav] = profiledata
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
async def test(request: web.Request) -> web.Response:
|
||||||
|
return web.Response(text="Welcome to the Dreamspace API. Please donate an important childhood memory to continue.")
|
||||||
|
|
||||||
|
def app_factory():
|
||||||
|
auth = key_auth_factory(conf.API['TOKEN'])
|
||||||
|
app = web.Application(middlewares=[auth])
|
||||||
|
app.cleanup_ctx.append(attach_db)
|
||||||
|
app.router.add_get('/', test)
|
||||||
|
register_routes(app.router)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app = app_factory()
|
||||||
|
web.run_app(app, port=int(conf.API['PORT']))
|
||||||
|
|
||||||
6
src/brand.py
Normal file
6
src/brand.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
import discord
|
||||||
|
|
||||||
|
|
||||||
|
# Theme
|
||||||
|
MAIN_COLOUR = discord.Colour.from_str('#11EA11')
|
||||||
|
ACCENT_COLOUR = discord.Colour.from_str('#EA11EA')
|
||||||
@@ -11,6 +11,7 @@ from meta.app import shardname, appname
|
|||||||
from meta.logger import log_wrap
|
from meta.logger import log_wrap
|
||||||
from utils.lib import utc_now
|
from utils.lib import utc_now
|
||||||
|
|
||||||
|
from datamodels import DataModel
|
||||||
from .data import CoreData
|
from .data import CoreData
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -29,7 +30,9 @@ class CoreCog(LionCog):
|
|||||||
def __init__(self, bot: LionBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.data = CoreData()
|
self.data = CoreData()
|
||||||
|
self.datamodel = DataModel()
|
||||||
bot.db.load_registry(self.data)
|
bot.db.load_registry(self.data)
|
||||||
|
bot.db.load_registry(self.datamodel)
|
||||||
|
|
||||||
self.app_config: Optional[CoreData.AppConfig] = None
|
self.app_config: Optional[CoreData.AppConfig] = None
|
||||||
self.bot_config: Optional[CoreData.BotConfig] = None
|
self.bot_config: Optional[CoreData.BotConfig] = None
|
||||||
@@ -43,6 +46,9 @@ class CoreCog(LionCog):
|
|||||||
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
|
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
|
||||||
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
|
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
|
||||||
|
|
||||||
|
await self.data.init()
|
||||||
|
await self.datamodel.init()
|
||||||
|
|
||||||
# Load the app command cache
|
# Load the app command cache
|
||||||
await self.reload_appcmd_cache()
|
await self.reload_appcmd_cache()
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
import base64
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
from data import Registry, RowModel, Table, RegisterEnum
|
from data import Registry, RowModel, Table, RegisterEnum
|
||||||
from data.columns import Integer, String, Timestamp, Column
|
from data.columns import Integer, String, Timestamp, Column
|
||||||
|
|
||||||
@@ -9,6 +13,52 @@ class EventType(Enum):
|
|||||||
CHEER = 'cheer',
|
CHEER = 'cheer',
|
||||||
PLAIN = 'plain',
|
PLAIN = 'plain',
|
||||||
|
|
||||||
|
def info(self):
|
||||||
|
if self is EventType.SUBSCRIBER:
|
||||||
|
info = EventTypeInfo(
|
||||||
|
EventType.SUBSCRIBER,
|
||||||
|
DataModel.subscriber_events,
|
||||||
|
("tier", "subscribed_length", "message"),
|
||||||
|
("tier", "subscribed_length", "message"),
|
||||||
|
('subscriber_tier', 'subscriber_length', 'subscriber_message'),
|
||||||
|
)
|
||||||
|
elif self is EventType.RAID:
|
||||||
|
info = EventTypeInfo(
|
||||||
|
EventType.RAID,
|
||||||
|
DataModel.raid_events,
|
||||||
|
('visitor_count',),
|
||||||
|
('viewer_count',),
|
||||||
|
('raid_visitor_count',),
|
||||||
|
)
|
||||||
|
elif self is EventType.CHEER:
|
||||||
|
info = EventTypeInfo(
|
||||||
|
EventType.CHEER,
|
||||||
|
DataModel.cheer_events,
|
||||||
|
('amount', 'cheer_type', 'message'),
|
||||||
|
('amount', 'cheer_type', 'message'),
|
||||||
|
('cheer_amount', 'cheer_type', 'cheer_message'),
|
||||||
|
)
|
||||||
|
elif self is EventType.PLAIN:
|
||||||
|
info = EventTypeInfo(
|
||||||
|
EventType.PLAIN,
|
||||||
|
DataModel.plain_events,
|
||||||
|
('message',),
|
||||||
|
('message',),
|
||||||
|
('plain_message',),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unexpected event type.")
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class EventTypeInfo(NamedTuple):
|
||||||
|
typ: EventType
|
||||||
|
table: Table
|
||||||
|
columns: tuple[str, ...]
|
||||||
|
params: tuple[str, ...]
|
||||||
|
detailcolumns: tuple[str, ...]
|
||||||
|
|
||||||
|
|
||||||
class DataModel(Registry):
|
class DataModel(Registry):
|
||||||
_EventType = RegisterEnum(EventType, 'EventType')
|
_EventType = RegisterEnum(EventType, 'EventType')
|
||||||
@@ -46,10 +96,10 @@ class DataModel(Registry):
|
|||||||
LEFT JOIN profiles_twitch USING (profileid)
|
LEFT JOIN profiles_twitch USING (profileid)
|
||||||
LEFT JOIN user_preferences USING (profileid);
|
LEFT JOIN user_preferences USING (profileid);
|
||||||
"""
|
"""
|
||||||
_tablename_ = ''
|
_tablename_ = 'dreamers'
|
||||||
_readonly_ = True
|
_readonly_ = True
|
||||||
|
|
||||||
profileid = Integer(primary=True)
|
user_id = Integer(primary=True)
|
||||||
name = String()
|
name = String()
|
||||||
twitch_id = Integer()
|
twitch_id = Integer()
|
||||||
preferences = String()
|
preferences = String()
|
||||||
@@ -103,9 +153,10 @@ class DataModel(Registry):
|
|||||||
------
|
------
|
||||||
CREATE TABLE documents (
|
CREATE TABLE documents (
|
||||||
document_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
document_id INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||||
document_data VARCHAR NOT NULL,
|
document_data TEXT NOT NULL,
|
||||||
seal INTEGER NOT NULL,
|
seal INTEGER NOT NULL,
|
||||||
metadata TEXT
|
metadata TEXT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
_tablename_ = 'documents'
|
_tablename_ = 'documents'
|
||||||
@@ -115,6 +166,15 @@ class DataModel(Registry):
|
|||||||
document_data = Column()
|
document_data = Column()
|
||||||
seal = Integer()
|
seal = Integer()
|
||||||
metadata = String()
|
metadata = String()
|
||||||
|
created_at = Timestamp()
|
||||||
|
|
||||||
|
def to_bytes(self):
|
||||||
|
"""
|
||||||
|
Helper method to decode the saved document data to a byte string.
|
||||||
|
This may fail if the saved string is not base64 encoded.
|
||||||
|
"""
|
||||||
|
byts = BytesIO(base64.b64decode(self.document_data))
|
||||||
|
return byts
|
||||||
|
|
||||||
class DocumentStamp(RowModel):
|
class DocumentStamp(RowModel):
|
||||||
"""
|
"""
|
||||||
@@ -180,15 +240,15 @@ class DataModel(Registry):
|
|||||||
event_id integer PRIMARY KEY,
|
event_id integer PRIMARY KEY,
|
||||||
event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'),
|
event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'),
|
||||||
message TEXT NOT NULL,
|
message TEXT NOT NULL,
|
||||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE raid_events (
|
CREATE TABLE raid_events (
|
||||||
event_id integer PRIMARY KEY,
|
event_id integer PRIMARY KEY,
|
||||||
event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'),
|
event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'),
|
||||||
visitor_count INTEGER NOT NULL,
|
visitor_count INTEGER NOT NULL,
|
||||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE cheer_events (
|
CREATE TABLE cheer_events (
|
||||||
event_id integer PRIMARY KEY,
|
event_id integer PRIMARY KEY,
|
||||||
@@ -196,8 +256,8 @@ class DataModel(Registry):
|
|||||||
amount INTEGER NOT NULL,
|
amount INTEGER NOT NULL,
|
||||||
cheer_type TEXT,
|
cheer_type TEXT,
|
||||||
message TEXT,
|
message TEXT,
|
||||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE subscriber_events (
|
CREATE TABLE subscriber_events (
|
||||||
event_id integer PRIMARY KEY,
|
event_id integer PRIMARY KEY,
|
||||||
@@ -205,7 +265,7 @@ class DataModel(Registry):
|
|||||||
subscribed_length INTEGER NOT NULL,
|
subscribed_length INTEGER NOT NULL,
|
||||||
tier INTEGER NOT NULL,
|
tier INTEGER NOT NULL,
|
||||||
message TEXT,
|
message TEXT,
|
||||||
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type)
|
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE VIEW event_details AS
|
CREATE VIEW event_details AS
|
||||||
@@ -225,12 +285,14 @@ class DataModel(Registry):
|
|||||||
subscriber_events.subscribed_length AS subscriber_length,
|
subscriber_events.subscribed_length AS subscriber_length,
|
||||||
subscriber_events.tier AS subscriber_tier,
|
subscriber_events.tier AS subscriber_tier,
|
||||||
subscriber_events.message AS subscriber_message,
|
subscriber_events.message AS subscriber_message,
|
||||||
|
documents.seal AS document_seal
|
||||||
FROM
|
FROM
|
||||||
events
|
events
|
||||||
LEFT JOIN plain_events USING (event_id)
|
LEFT JOIN plain_events USING (event_id)
|
||||||
LEFT JOIN raid_events USING (event_id)
|
LEFT JOIN raid_events USING (event_id)
|
||||||
LEFT JOIN cheer_events USING (event_id)
|
LEFT JOIN cheer_events USING (event_id)
|
||||||
LEFT JOIN subscriber_events USING (event_id)
|
LEFT JOIN subscriber_events USING (event_id)
|
||||||
|
LEFT JOIN documents USING (document_id)
|
||||||
ORDER BY events.occurred_at ASC;
|
ORDER BY events.occurred_at ASC;
|
||||||
"""
|
"""
|
||||||
_tablename_ = 'event_details'
|
_tablename_ = 'event_details'
|
||||||
@@ -251,6 +313,7 @@ class DataModel(Registry):
|
|||||||
subscriber_length = Integer()
|
subscriber_length = Integer()
|
||||||
subscriber_tier = Integer()
|
subscriber_tier = Integer()
|
||||||
subscriber_message = String()
|
subscriber_message = String()
|
||||||
|
document_seal = Integer()
|
||||||
|
|
||||||
|
|
||||||
class Specimen(RowModel):
|
class Specimen(RowModel):
|
||||||
@@ -263,6 +326,7 @@ class DataModel(Registry):
|
|||||||
born_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
born_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
forgotten_at TIMESTAMPTZ
|
forgotten_at TIMESTAMPTZ
|
||||||
);
|
);
|
||||||
|
CREATE UNIQUE INDEX ON user_specimens (owner_id) WHERE forgotten_at IS NULL;
|
||||||
"""
|
"""
|
||||||
_tablename_ = 'user_specimens'
|
_tablename_ = 'user_specimens'
|
||||||
_cache_ = {}
|
_cache_ = {}
|
||||||
|
|||||||
76
src/meta/CrocBot.py
Normal file
76
src/meta/CrocBot.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import twitchio
|
||||||
|
from twitchio.ext import commands
|
||||||
|
from twitchio.ext import pubsub
|
||||||
|
from twitchio.ext.commands.core import itertools
|
||||||
|
|
||||||
|
from data import Database
|
||||||
|
|
||||||
|
from .config import Conf
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CrocBot(commands.Bot):
|
||||||
|
def __init__(self, *args,
|
||||||
|
config: Conf,
|
||||||
|
data: Database,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.config = config
|
||||||
|
self.data = data
|
||||||
|
self.pubsub = pubsub.PubSubPool(self)
|
||||||
|
|
||||||
|
self._member_cache = defaultdict(dict)
|
||||||
|
|
||||||
|
async def event_ready(self):
|
||||||
|
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
||||||
|
|
||||||
|
async def event_join(self, channel: twitchio.Channel, user: twitchio.User):
|
||||||
|
self._member_cache[channel.name][user.name] = user
|
||||||
|
|
||||||
|
async def event_message(self, message: twitchio.Message):
|
||||||
|
if message.channel and message.author:
|
||||||
|
self._member_cache[message.channel.name][message.author.name] = message.author
|
||||||
|
await self.handle_commands(message)
|
||||||
|
|
||||||
|
async def seek_user(self, userstr: str, matching=True, fuzzy=True):
|
||||||
|
if userstr.startswith('@'):
|
||||||
|
matching = False
|
||||||
|
userstr = userstr.strip('@ ')
|
||||||
|
|
||||||
|
result = None
|
||||||
|
if matching and len(userstr) >= 3:
|
||||||
|
lowered = userstr.lower()
|
||||||
|
full_matches = []
|
||||||
|
for user in itertools.chain(*(cmems.values() for cmems in self._member_cache.values())):
|
||||||
|
matchstr = user.name.lower()
|
||||||
|
print(matchstr)
|
||||||
|
if matchstr.startswith(lowered):
|
||||||
|
result = user
|
||||||
|
break
|
||||||
|
if lowered in matchstr:
|
||||||
|
full_matches.append(user)
|
||||||
|
if result is None and full_matches:
|
||||||
|
result = full_matches[0]
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
lookup = userstr
|
||||||
|
elif result.id is None:
|
||||||
|
lookup = result.name
|
||||||
|
else:
|
||||||
|
lookup = None
|
||||||
|
|
||||||
|
if lookup:
|
||||||
|
found = await self.fetch_users(names=[lookup])
|
||||||
|
if found:
|
||||||
|
result = found[0]
|
||||||
|
|
||||||
|
# No matches found
|
||||||
|
return result
|
||||||
@@ -3,6 +3,7 @@ this_package = 'modules'
|
|||||||
active = [
|
active = [
|
||||||
'.profiles',
|
'.profiles',
|
||||||
'.sysadmin',
|
'.sysadmin',
|
||||||
|
'.dreamspace',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
8
src/modules/dreamspace/__init__.py
Normal file
8
src/modules/dreamspace/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
from .cog import DreamCog
|
||||||
|
await bot.add_cog(DreamCog(bot))
|
||||||
70
src/modules/dreamspace/cog.py
Normal file
70
src/modules/dreamspace/cog.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord import app_commands as appcmds
|
||||||
|
from discord.ext import commands as cmds
|
||||||
|
|
||||||
|
from meta import LionCog, LionBot, LionContext
|
||||||
|
from meta.logger import log_wrap
|
||||||
|
from utils.lib import utc_now
|
||||||
|
|
||||||
|
from . import logger
|
||||||
|
from .ui.docviewer import DocumentViewer
|
||||||
|
|
||||||
|
|
||||||
|
class DreamCog(LionCog):
|
||||||
|
"""
|
||||||
|
Discord-facting interface for Dreamspace Adventures
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bot: LionBot):
|
||||||
|
self.bot = bot
|
||||||
|
self.data = bot.core.datamodel
|
||||||
|
|
||||||
|
async def cog_load(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@log_wrap(action="Dreamer migration")
|
||||||
|
async def migrate_dreamer(self, source_profile, target_profile):
|
||||||
|
"""
|
||||||
|
Called when two dreamer profiles need to merge.
|
||||||
|
|
||||||
|
For example, when a user links a second twitch profile.
|
||||||
|
|
||||||
|
:TODO-MARKER:
|
||||||
|
Most of the migration logic is simple, e.g. just update the profileid
|
||||||
|
on the old events to the new profile.
|
||||||
|
The same applies to transactions and probably to inventory items.
|
||||||
|
However, there are some subtle choices to make, such as what to do
|
||||||
|
if both the old and the new profile have an active specimen?
|
||||||
|
A profile can only have one active specimen at a time.
|
||||||
|
There is also the question of how to merge user preferences, when those exist.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
# User command: view their dreamer card, wallet inventory etc
|
||||||
|
|
||||||
|
# (Admin): View events/documents matching certain criteria
|
||||||
|
|
||||||
|
# (User): View own event cards with info?
|
||||||
|
# Let's make a demo viewer which lists their event cards and let's them open one via select?
|
||||||
|
# /documents -> Show a paged list of documents, select option displays the document in a viewer
|
||||||
|
@cmds.hybrid_command(
|
||||||
|
name='documents',
|
||||||
|
description="View your printer log!"
|
||||||
|
)
|
||||||
|
async def documents_cmd(self, ctx: LionContext):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||||
|
events = await self.data.Events.fetch_where(user_id=profile.profileid)
|
||||||
|
docids = [event.document_id for event in events if event.document_id is not None]
|
||||||
|
if not docids:
|
||||||
|
await ctx.error_reply("You don't have any documents yet!")
|
||||||
|
return
|
||||||
|
|
||||||
|
view = DocumentViewer(self.bot, ctx.interaction.user.id, filter=(self.data.Document.document_id == docids))
|
||||||
|
await view.run(ctx.interaction)
|
||||||
|
|
||||||
|
# (User): View Specimen information
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
36
src/modules/dreamspace/events.py
Normal file
36
src/modules/dreamspace/events.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from datamodels import DataModel
|
||||||
|
|
||||||
|
|
||||||
|
class Event:
|
||||||
|
_typs = {}
|
||||||
|
|
||||||
|
def __init__(self, event_row: DataModel.Events, **kwargs):
|
||||||
|
self.row = event_row
|
||||||
|
|
||||||
|
def __getattribute__(self, name: str):
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_document(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_user(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class Document:
|
||||||
|
def as_bytes(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_stamps(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
async def refresh(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class User:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class Stamp:
|
||||||
|
...
|
||||||
171
src/modules/dreamspace/ui/docviewer.py
Normal file
171
src/modules/dreamspace/ui/docviewer.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
import binascii
|
||||||
|
from itertools import chain
|
||||||
|
from typing import Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import asyncio
|
||||||
|
import datetime as dt
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ui.select import select, Select, UserSelect
|
||||||
|
from discord.ui.button import button, Button
|
||||||
|
from discord.ui.text_input import TextInput
|
||||||
|
from discord.enums import ButtonStyle, TextStyle
|
||||||
|
from discord.components import SelectOption
|
||||||
|
|
||||||
|
from datamodels import DataModel
|
||||||
|
from meta import LionBot, conf
|
||||||
|
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
|
||||||
|
from data import ORDER, Condition
|
||||||
|
|
||||||
|
from utils.ui import MessageUI, input
|
||||||
|
from utils.lib import MessageArgs, tabulate, utc_now
|
||||||
|
|
||||||
|
from .. import logger
|
||||||
|
|
||||||
|
class DocumentViewer(MessageUI):
|
||||||
|
"""
|
||||||
|
Simple pager which displays a filtered list of Documents.
|
||||||
|
"""
|
||||||
|
block_len = 5
|
||||||
|
|
||||||
|
def __init__(self, bot: LionBot, callerid: int, filter: Condition, **kwargs):
|
||||||
|
super().__init__(callerid=callerid, **kwargs)
|
||||||
|
|
||||||
|
self.bot = bot
|
||||||
|
self.data: DataModel = bot.core.datamodel
|
||||||
|
self.filter = filter
|
||||||
|
|
||||||
|
# Paging state
|
||||||
|
self._pagen = 0
|
||||||
|
self.blocks = [[]]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def page_count(self):
|
||||||
|
return len(self.blocks)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pagen(self):
|
||||||
|
self._pagen %= self.page_count
|
||||||
|
return self._pagen
|
||||||
|
|
||||||
|
@pagen.setter
|
||||||
|
def pagen(self, value):
|
||||||
|
self._pagen = value % self.page_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_page(self):
|
||||||
|
return self.blocks[self.pagen]
|
||||||
|
|
||||||
|
# ----- UI Components -----
|
||||||
|
|
||||||
|
# Page backwards
|
||||||
|
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
|
||||||
|
async def prev_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
await press.response.defer(thinking=True)
|
||||||
|
self.pagen -= 1
|
||||||
|
await self.refresh(thinking=press)
|
||||||
|
|
||||||
|
# Jump to page
|
||||||
|
@button(label="JUMP_PLACEHOLDER", style=ButtonStyle.blurple)
|
||||||
|
async def jump_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
"""
|
||||||
|
Jump to page button.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
interaction, value = await input(
|
||||||
|
press,
|
||||||
|
title="Jump to page",
|
||||||
|
question="Page number to jump to"
|
||||||
|
)
|
||||||
|
value = value.strip()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not value.lstrip('- ').isdigit():
|
||||||
|
error = discord.Embed(title="Invalid page number, please try again!",
|
||||||
|
colour=discord.Colour.brand_red())
|
||||||
|
await interaction.response.send_message(embed=error, ephemeral=True)
|
||||||
|
else:
|
||||||
|
await interaction.response.defer(thinking=True)
|
||||||
|
pagen = int(value.lstrip('- '))
|
||||||
|
if value.startswith('-'):
|
||||||
|
pagen = -1 * pagen
|
||||||
|
elif pagen > 0:
|
||||||
|
pagen = pagen - 1
|
||||||
|
self.pagen = pagen
|
||||||
|
await self.refresh(thinking=interaction)
|
||||||
|
|
||||||
|
async def jump_button_refresh(self):
|
||||||
|
component = self.jump_button
|
||||||
|
component.label = f"{self.pagen + 1}/{self.page_count}"
|
||||||
|
component.disabled = (self.page_count <= 1)
|
||||||
|
|
||||||
|
# Page forwards
|
||||||
|
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
|
||||||
|
async def next_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
await press.response.defer(thinking=True)
|
||||||
|
self.pagen += 1
|
||||||
|
await self.refresh(thinking=press)
|
||||||
|
|
||||||
|
# Quit
|
||||||
|
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
|
||||||
|
async def quit_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
"""
|
||||||
|
Quit the UI.
|
||||||
|
"""
|
||||||
|
await press.response.defer()
|
||||||
|
# if self.child_viewer:
|
||||||
|
# await self.child_viewer.quit()
|
||||||
|
await self.quit()
|
||||||
|
|
||||||
|
# ----- UI Flow -----
|
||||||
|
|
||||||
|
async def make_message(self) -> MessageArgs:
|
||||||
|
files = []
|
||||||
|
embeds = []
|
||||||
|
for doc in self.current_page:
|
||||||
|
try:
|
||||||
|
imagedata = doc.to_bytes()
|
||||||
|
imagedata.seek(0)
|
||||||
|
except binascii.Error:
|
||||||
|
continue
|
||||||
|
fn = f"doc-{doc.document_id}.png"
|
||||||
|
file = discord.File(imagedata, fn)
|
||||||
|
embed = discord.Embed()
|
||||||
|
embed.set_image(url=f"attachment://{fn}")
|
||||||
|
files.append(file)
|
||||||
|
embeds.append(embed)
|
||||||
|
|
||||||
|
if not embeds:
|
||||||
|
embed = discord.Embed(description="You don't have any documents yet!")
|
||||||
|
embeds.append(embed)
|
||||||
|
|
||||||
|
print(f"FILES: {files}")
|
||||||
|
|
||||||
|
return MessageArgs(files=files, embeds=embeds)
|
||||||
|
|
||||||
|
async def refresh_layout(self):
|
||||||
|
to_refresh = (
|
||||||
|
self.jump_button_refresh(),
|
||||||
|
)
|
||||||
|
await asyncio.gather(*to_refresh)
|
||||||
|
if self.page_count > 1:
|
||||||
|
page_line = (
|
||||||
|
self.prev_button,
|
||||||
|
self.jump_button,
|
||||||
|
self.quit_button,
|
||||||
|
self.next_button,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
page_line = (self.quit_button,)
|
||||||
|
|
||||||
|
self.set_layout(page_line)
|
||||||
|
|
||||||
|
async def reload(self):
|
||||||
|
docs = await self.data.Document.fetch_where(self.filter).order_by('created_at', ORDER.DESC)
|
||||||
|
blocks = [
|
||||||
|
docs[i:i+self.block_len]
|
||||||
|
for i in range(0, len(docs), self.block_len)
|
||||||
|
]
|
||||||
|
self.blocks = blocks or [[]]
|
||||||
|
|
||||||
406
src/modules/dreamspace/ui/eventviewer.py
Normal file
406
src/modules/dreamspace/ui/eventviewer.py
Normal file
@@ -0,0 +1,406 @@
|
|||||||
|
from itertools import chain
|
||||||
|
from typing import Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import asyncio
|
||||||
|
import datetime as dt
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ui.select import select, Select, UserSelect
|
||||||
|
from discord.ui.button import button, Button
|
||||||
|
from discord.ui.text_input import TextInput
|
||||||
|
from discord.enums import ButtonStyle, TextStyle
|
||||||
|
from discord.components import SelectOption
|
||||||
|
|
||||||
|
from datamodels import DataModel
|
||||||
|
from meta import LionBot, conf
|
||||||
|
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
|
||||||
|
from data import ORDER, Condition
|
||||||
|
|
||||||
|
from utils.ui import MessageUI, input
|
||||||
|
from utils.lib import MessageArgs, tabulate, utc_now
|
||||||
|
|
||||||
|
from .. import logger
|
||||||
|
|
||||||
|
|
||||||
|
class EventsUI(MessageUI):
|
||||||
|
block_len = 10
|
||||||
|
|
||||||
|
def __init__(self, bot: LionBot, callerid: int, filter: Condition, **kwargs):
|
||||||
|
super().__init__(callerid=callerid, **kwargs)
|
||||||
|
|
||||||
|
self.bot = bot
|
||||||
|
self.data: DataModel = bot.core.datamodel
|
||||||
|
self.filter = Condition
|
||||||
|
|
||||||
|
# Paging state
|
||||||
|
self._pagen = 0
|
||||||
|
self.blocks = [[]]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def page_count(self):
|
||||||
|
return len(self.blocks)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pagen(self):
|
||||||
|
self._pagen %= self.page_count
|
||||||
|
return self._pagen
|
||||||
|
|
||||||
|
@pagen.setter
|
||||||
|
def pagen(self, value):
|
||||||
|
self._pagen = value % self.page_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_page(self):
|
||||||
|
return self.blocks[self.pagen]
|
||||||
|
|
||||||
|
# ----- UI Components -----
|
||||||
|
|
||||||
|
# Page backwards
|
||||||
|
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
|
||||||
|
async def prev_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
await press.response.defer(thinking=True)
|
||||||
|
self.pagen -= 1
|
||||||
|
await self.refresh(thinking=press)
|
||||||
|
|
||||||
|
# Jump to page
|
||||||
|
|
||||||
|
# Page forwards
|
||||||
|
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
|
||||||
|
async def next_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
await press.response.defer(thinking=True)
|
||||||
|
self.pagen += 1
|
||||||
|
await self.refresh(thinking=press)
|
||||||
|
|
||||||
|
# Quit
|
||||||
|
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
|
||||||
|
async def quit_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
"""
|
||||||
|
Quit the UI.
|
||||||
|
"""
|
||||||
|
await press.response.defer()
|
||||||
|
# if self.child_viewer:
|
||||||
|
# await self.child_viewer.quit()
|
||||||
|
await self.quit()
|
||||||
|
|
||||||
|
# ----- UI Flow -----
|
||||||
|
|
||||||
|
async def make_message(self) -> MessageArgs:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def refresh_layout(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
async def reload(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TicketListUI(MessageUI):
|
||||||
|
# Select Ticket
|
||||||
|
@select(
|
||||||
|
cls=Select,
|
||||||
|
placeholder="TICKETS_MENU_PLACEHOLDER",
|
||||||
|
min_values=1, max_values=1
|
||||||
|
)
|
||||||
|
async def tickets_menu(self, selection: discord.Interaction, selected: Select):
|
||||||
|
await selection.response.defer(thinking=True, ephemeral=True)
|
||||||
|
if selected.values:
|
||||||
|
ticketid = int(selected.values[0])
|
||||||
|
ticket = await Ticket.fetch_ticket(self.bot, ticketid)
|
||||||
|
ticketui = TicketUI(self.bot, ticket, self._callerid)
|
||||||
|
if self.child_ticket:
|
||||||
|
await self.child_ticket.quit()
|
||||||
|
self.child_ticket = ticketui
|
||||||
|
await ticketui.run(selection)
|
||||||
|
|
||||||
|
async def tickets_menu_refresh(self):
|
||||||
|
menu = self.tickets_menu
|
||||||
|
t = self.bot.translator.t
|
||||||
|
menu.placeholder = t(_p(
|
||||||
|
'ui:tickets|menu:tickets|placeholder',
|
||||||
|
"Select Ticket"
|
||||||
|
))
|
||||||
|
options = []
|
||||||
|
for ticket in self.current_page:
|
||||||
|
option = SelectOption(
|
||||||
|
label=f"Ticket #{ticket.data.guild_ticketid}",
|
||||||
|
value=str(ticket.data.ticketid)
|
||||||
|
)
|
||||||
|
options.append(option)
|
||||||
|
menu.options = options
|
||||||
|
|
||||||
|
# Backwards
|
||||||
|
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
|
||||||
|
async def prev_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
await press.response.defer(thinking=True, ephemeral=True)
|
||||||
|
self.pagen -= 1
|
||||||
|
await self.refresh(thinking=press)
|
||||||
|
|
||||||
|
# Jump to page
|
||||||
|
@button(label="JUMP_PLACEHOLDER", style=ButtonStyle.blurple)
|
||||||
|
async def jump_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
"""
|
||||||
|
Jump-to-page button.
|
||||||
|
Loads a page-switch dialogue.
|
||||||
|
"""
|
||||||
|
t = self.bot.translator.t
|
||||||
|
try:
|
||||||
|
interaction, value = await input(
|
||||||
|
press,
|
||||||
|
title=t(_p(
|
||||||
|
'ui:tickets|button:jump|input:title',
|
||||||
|
"Jump to page"
|
||||||
|
)),
|
||||||
|
question=t(_p(
|
||||||
|
'ui:tickets|button:jump|input:question',
|
||||||
|
"Page number to jump to"
|
||||||
|
))
|
||||||
|
)
|
||||||
|
value = value.strip()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not value.lstrip('- ').isdigit():
|
||||||
|
error_embed = discord.Embed(
|
||||||
|
title=t(_p(
|
||||||
|
'ui:tickets|button:jump|error:invalid_page',
|
||||||
|
"Invalid page number, please try again!"
|
||||||
|
)),
|
||||||
|
colour=discord.Colour.brand_red()
|
||||||
|
)
|
||||||
|
await interaction.response.send_message(embed=error_embed, ephemeral=True)
|
||||||
|
else:
|
||||||
|
await interaction.response.defer(thinking=True)
|
||||||
|
pagen = int(value.lstrip('- '))
|
||||||
|
if value.startswith('-'):
|
||||||
|
pagen = -1 * pagen
|
||||||
|
elif pagen > 0:
|
||||||
|
pagen = pagen - 1
|
||||||
|
self.pagen = pagen
|
||||||
|
await self.refresh(thinking=interaction)
|
||||||
|
|
||||||
|
async def jump_button_refresh(self):
|
||||||
|
component = self.jump_button
|
||||||
|
component.label = f"{self.pagen + 1}/{self.page_count}"
|
||||||
|
component.disabled = (self.page_count <= 1)
|
||||||
|
|
||||||
|
# Forward
|
||||||
|
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
|
||||||
|
async def next_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
await press.response.defer(thinking=True)
|
||||||
|
self.pagen += 1
|
||||||
|
await self.refresh(thinking=press)
|
||||||
|
|
||||||
|
# Quit
|
||||||
|
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
|
||||||
|
async def quit_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
"""
|
||||||
|
Quit the UI.
|
||||||
|
"""
|
||||||
|
await press.response.defer()
|
||||||
|
if self.child_ticket:
|
||||||
|
await self.child_ticket.quit()
|
||||||
|
await self.quit()
|
||||||
|
|
||||||
|
# ----- UI Flow -----
|
||||||
|
def _format_ticket(self, ticket) -> str:
|
||||||
|
"""
|
||||||
|
Format a ticket into a single embed line.
|
||||||
|
"""
|
||||||
|
components = (
|
||||||
|
"[#{ticketid}]({link})",
|
||||||
|
"{created}",
|
||||||
|
"`{type}[{state}]`",
|
||||||
|
"<@{targetid}>",
|
||||||
|
"{content}",
|
||||||
|
)
|
||||||
|
|
||||||
|
formatstr = ' | '.join(components)
|
||||||
|
|
||||||
|
data = ticket.data
|
||||||
|
if not data.content:
|
||||||
|
content = 'No Content'
|
||||||
|
elif len(data.content) > 100:
|
||||||
|
content = data.content[:97] + '...'
|
||||||
|
else:
|
||||||
|
content = data.content
|
||||||
|
|
||||||
|
ticketstr = formatstr.format(
|
||||||
|
ticketid=data.guild_ticketid,
|
||||||
|
link=ticket.jump_url or 'https://lionbot.org',
|
||||||
|
created=discord.utils.format_dt(data.created_at, 'd'),
|
||||||
|
type=data.ticket_type.name,
|
||||||
|
state=data.ticket_state.name,
|
||||||
|
targetid=data.targetid,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
if data.ticket_state is TicketState.PARDONED:
|
||||||
|
ticketstr = f"~~{ticketstr}~~"
|
||||||
|
return ticketstr
|
||||||
|
|
||||||
|
async def make_message(self) -> MessageArgs:
|
||||||
|
t = self.bot.translator.t
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=t(_p(
|
||||||
|
'ui:tickets|embed|title',
|
||||||
|
"Moderation Tickets in {guild}"
|
||||||
|
)).format(guild=self.guild.name),
|
||||||
|
timestamp=utc_now()
|
||||||
|
)
|
||||||
|
tickets = self.current_page
|
||||||
|
if tickets:
|
||||||
|
desc = '\n'.join(self._format_ticket(ticket) for ticket in tickets)
|
||||||
|
else:
|
||||||
|
desc = t(_p(
|
||||||
|
'ui:tickets|embed|desc:no_tickets',
|
||||||
|
"No tickets matching the given criteria!"
|
||||||
|
))
|
||||||
|
embed.description = desc
|
||||||
|
|
||||||
|
filterstr = self.filters.formatted()
|
||||||
|
if filterstr:
|
||||||
|
embed.add_field(
|
||||||
|
name=t(_p(
|
||||||
|
'ui:tickets|embed|field:filters|name',
|
||||||
|
"Filters"
|
||||||
|
)),
|
||||||
|
value=filterstr,
|
||||||
|
inline=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return MessageArgs(embed=embed)
|
||||||
|
|
||||||
|
async def refresh_layout(self):
|
||||||
|
to_refresh = (
|
||||||
|
self.edit_filter_button_refresh(),
|
||||||
|
self.select_ticket_button_refresh(),
|
||||||
|
self.pardon_button_refresh(),
|
||||||
|
self.tickets_menu_refresh(),
|
||||||
|
self.filter_type_menu_refresh(),
|
||||||
|
self.filter_state_menu_refresh(),
|
||||||
|
self.filter_target_menu_refresh(),
|
||||||
|
self.jump_button_refresh(),
|
||||||
|
)
|
||||||
|
await asyncio.gather(*to_refresh)
|
||||||
|
|
||||||
|
action_line = (
|
||||||
|
self.edit_filter_button,
|
||||||
|
self.select_ticket_button,
|
||||||
|
self.pardon_button,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.page_count > 1:
|
||||||
|
page_line = (
|
||||||
|
self.prev_button,
|
||||||
|
self.jump_button,
|
||||||
|
self.quit_button,
|
||||||
|
self.next_button,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
page_line = ()
|
||||||
|
action_line = (*action_line, self.quit_button)
|
||||||
|
|
||||||
|
if self.show_filters:
|
||||||
|
menus = (
|
||||||
|
(self.filter_type_menu,),
|
||||||
|
(self.filter_state_menu,),
|
||||||
|
(self.filter_target_menu,),
|
||||||
|
)
|
||||||
|
elif self.show_tickets and self.current_page:
|
||||||
|
menus = ((self.tickets_menu,),)
|
||||||
|
else:
|
||||||
|
menus = ()
|
||||||
|
|
||||||
|
self.set_layout(
|
||||||
|
action_line,
|
||||||
|
*menus,
|
||||||
|
page_line,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def reload(self):
|
||||||
|
tickets = await Ticket.fetch_tickets(
|
||||||
|
self.bot,
|
||||||
|
*self.filters.conditions(),
|
||||||
|
guildid=self.guild.id,
|
||||||
|
)
|
||||||
|
blocks = [
|
||||||
|
tickets[i:i+self.block_len]
|
||||||
|
for i in range(0, len(tickets), self.block_len)
|
||||||
|
]
|
||||||
|
self.blocks = blocks or [[]]
|
||||||
|
|
||||||
|
|
||||||
|
class TicketUI(MessageUI):
|
||||||
|
def __init__(self, bot: LionBot, ticket: Ticket, callerid: int, **kwargs):
|
||||||
|
super().__init__(callerid=callerid, **kwargs)
|
||||||
|
|
||||||
|
self.bot = bot
|
||||||
|
self.ticket = ticket
|
||||||
|
|
||||||
|
# ----- API -----
|
||||||
|
|
||||||
|
# ----- UI Components -----
|
||||||
|
# Pardon Ticket
|
||||||
|
@button(
|
||||||
|
label="PARDON_BUTTON_PLACEHOLDER",
|
||||||
|
style=ButtonStyle.red
|
||||||
|
)
|
||||||
|
async def pardon_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
t = self.bot.translator.t
|
||||||
|
|
||||||
|
modal_title = t(_p(
|
||||||
|
'ui:ticket|button:pardon|modal:reason|title',
|
||||||
|
"Pardon Moderation Ticket"
|
||||||
|
))
|
||||||
|
input_field = TextInput(
|
||||||
|
label=t(_p(
|
||||||
|
'ui:ticket|button:pardon|modal:reason|field|label',
|
||||||
|
"Why are you pardoning this ticket?"
|
||||||
|
)),
|
||||||
|
style=TextStyle.long,
|
||||||
|
min_length=0,
|
||||||
|
max_length=1024,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
interaction, reason = await input(
|
||||||
|
press, modal_title, field=input_field, timeout=300,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise ResponseTimedOut
|
||||||
|
|
||||||
|
await interaction.response.defer(thinking=True, ephemeral=True)
|
||||||
|
|
||||||
|
await self.ticket.pardon(modid=press.user.id, reason=reason)
|
||||||
|
await self.refresh(thinking=interaction)
|
||||||
|
|
||||||
|
|
||||||
|
async def pardon_button_refresh(self):
|
||||||
|
button = self.pardon_button
|
||||||
|
t = self.bot.translator.t
|
||||||
|
button.label = t(_p(
|
||||||
|
'ui:ticket|button:pardon|label',
|
||||||
|
"Pardon"
|
||||||
|
))
|
||||||
|
button.disabled = (self.ticket.data.ticket_state is TicketState.PARDONED)
|
||||||
|
|
||||||
|
# Quit
|
||||||
|
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
|
||||||
|
async def quit_button(self, press: discord.Interaction, pressed: Button):
|
||||||
|
"""
|
||||||
|
Quit the UI.
|
||||||
|
"""
|
||||||
|
await press.response.defer()
|
||||||
|
await self.quit()
|
||||||
|
|
||||||
|
# ----- UI Flow -----
|
||||||
|
async def make_message(self) -> MessageArgs:
|
||||||
|
return await self.ticket.make_message()
|
||||||
|
|
||||||
|
async def refresh_layout(self):
|
||||||
|
await self.pardon_button_refresh()
|
||||||
|
self.set_layout(
|
||||||
|
(self.pardon_button, self.quit_button,)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def reload(self):
|
||||||
|
await self.ticket.data.refresh()
|
||||||
0
src/modules/dreamspace/ui/specimen.py
Normal file
0
src/modules/dreamspace/ui/specimen.py
Normal file
@@ -21,6 +21,8 @@ from .data import ProfileData
|
|||||||
from .profile import UserProfile
|
from .profile import UserProfile
|
||||||
from .community import Community
|
from .community import Community
|
||||||
|
|
||||||
|
from .ui import TwitchLinkStatic, TwitchLinkFlow
|
||||||
|
|
||||||
|
|
||||||
class ProfileCog(LionCog):
|
class ProfileCog(LionCog):
|
||||||
def __init__(self, bot: LionBot):
|
def __init__(self, bot: LionBot):
|
||||||
@@ -34,6 +36,8 @@ class ProfileCog(LionCog):
|
|||||||
async def cog_load(self):
|
async def cog_load(self):
|
||||||
await self.data.init()
|
await self.data.init()
|
||||||
|
|
||||||
|
self.bot.add_view(TwitchLinkStatic(timeout=None))
|
||||||
|
|
||||||
async def cog_check(self, ctx):
|
async def cog_check(self, ctx):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -197,6 +201,16 @@ class ProfileCog(LionCog):
|
|||||||
community = await Community.create_from_twitch(self.bot, user)
|
community = await Community.create_from_twitch(self.bot, user)
|
||||||
return community
|
return community
|
||||||
|
|
||||||
|
# ----- Admin Commands -----
|
||||||
|
@cmds.hybrid_command(
|
||||||
|
name='linkoffer',
|
||||||
|
description="Send a message with a permanent button for profile linking"
|
||||||
|
)
|
||||||
|
@appcmds.default_permissions(manage_guild=True)
|
||||||
|
async def linkoffer_cmd(self, ctx: LionContext):
|
||||||
|
view = TwitchLinkStatic(timeout=None)
|
||||||
|
await ctx.channel.send(embed=view.embed, view=view)
|
||||||
|
|
||||||
# ----- Profile Commands -----
|
# ----- Profile Commands -----
|
||||||
@cmds.hybrid_group(
|
@cmds.hybrid_group(
|
||||||
name='profiles',
|
name='profiles',
|
||||||
@@ -217,6 +231,13 @@ class ProfileCog(LionCog):
|
|||||||
description="Link a twitch account to your current profile."
|
description="Link a twitch account to your current profile."
|
||||||
)
|
)
|
||||||
async def profiles_link_twitch_cmd(self, ctx: LionContext):
|
async def profiles_link_twitch_cmd(self, ctx: LionContext):
|
||||||
|
if not ctx.interaction:
|
||||||
|
return
|
||||||
|
flowui = TwitchLinkFlow(self.bot, ctx.author, callerid=ctx.author.id)
|
||||||
|
await flowui.run(ctx.interaction)
|
||||||
|
await flowui.wait()
|
||||||
|
|
||||||
|
async def old_profiles_link_twitch_cmd(self, ctx: LionContext):
|
||||||
if not ctx.interaction:
|
if not ctx.interaction:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
1
src/modules/profiles/ui/__init__.py
Normal file
1
src/modules/profiles/ui/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .twitchlink import TwitchLinkStatic, TwitchLinkFlow
|
||||||
337
src/modules/profiles/ui/twitchlink.py
Normal file
337
src/modules/profiles/ui/twitchlink.py
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
"""
|
||||||
|
UI Views for Twitch linkage.
|
||||||
|
|
||||||
|
- Persistent view with interaction-button to enter link flow.
|
||||||
|
- We don't store the view, but we listen to interaction button id.
|
||||||
|
- Command to enter link flow.
|
||||||
|
|
||||||
|
For link flow, send ephemeral embed with instructions and what to expect, with link button below.
|
||||||
|
After auth is granted through OAuth flow (or if not granted, e.g. on timeout or failure)
|
||||||
|
edit the embed to reflect auth situation.
|
||||||
|
|
||||||
|
If migration occurred, add the migration text as a field to the embed.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from datetime import timedelta
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import discord
|
||||||
|
from discord.ui.button import button, Button
|
||||||
|
from discord.enums import ButtonStyle
|
||||||
|
from twitchAPI.helper import first
|
||||||
|
|
||||||
|
from meta import LionBot
|
||||||
|
from meta.errors import SafeCancellation
|
||||||
|
from meta.logger import log_wrap
|
||||||
|
|
||||||
|
from utils.ui import MessageUI
|
||||||
|
from utils.lib import MessageArgs, utc_now
|
||||||
|
from utils.ui.leo import LeoUI
|
||||||
|
|
||||||
|
from modules.profiles.profile import UserProfile
|
||||||
|
|
||||||
|
import brand
|
||||||
|
|
||||||
|
from .. import logger
|
||||||
|
|
||||||
|
class TwitchLinkStatic(LeoUI):
|
||||||
|
"""
|
||||||
|
Static UI whose only job is to display a persistent button
|
||||||
|
to ask people to connect their twitch account.
|
||||||
|
"""
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._embed: Optional[discord.Embed] = None
|
||||||
|
print("INITIALISATION")
|
||||||
|
|
||||||
|
async def interaction_check(self, interaction: discord.Interaction):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embed(self) -> discord.Embed:
|
||||||
|
"""
|
||||||
|
This is the persistent message people will see with the button that starts the Oauth flow.
|
||||||
|
|
||||||
|
Not sure what this should actually say, or whether it should be customisable via command.
|
||||||
|
|
||||||
|
:TODO-MARKER:
|
||||||
|
"""
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Link your Twitch account!",
|
||||||
|
description=(
|
||||||
|
"To participate in the Dreamspace Adventure Game :TM:, "
|
||||||
|
"please start by pressing the button below to begin the login flow with Twitch!"
|
||||||
|
),
|
||||||
|
colour=brand.ACCENT_COLOUR,
|
||||||
|
)
|
||||||
|
return embed
|
||||||
|
|
||||||
|
@embed.setter
|
||||||
|
def embed(self, value):
|
||||||
|
self._embed = value
|
||||||
|
|
||||||
|
@button(label="Connect", custom_id="BTN-LINK-TWITCH", style=ButtonStyle.green, emoji='🔗')
|
||||||
|
@log_wrap(action="link-twitch-btn")
|
||||||
|
async def button_linker(self, interaction: discord.Interaction, btn: Button):
|
||||||
|
# Here we just reply to the interaction with the AuthFlow UI
|
||||||
|
# TODO
|
||||||
|
print("RESPONDING")
|
||||||
|
flowui = TwitchLinkFlow(interaction.client, interaction.user, callerid=interaction.user.id)
|
||||||
|
await flowui.run(interaction)
|
||||||
|
await flowui.wait()
|
||||||
|
|
||||||
|
|
||||||
|
class FlowState(IntEnum):
|
||||||
|
SETUP = -1
|
||||||
|
WAITING = 0
|
||||||
|
|
||||||
|
CANCELLED = 1
|
||||||
|
TIMEOUT = 2
|
||||||
|
ERRORED = 3
|
||||||
|
|
||||||
|
WORKING = 9
|
||||||
|
DONE = 10
|
||||||
|
|
||||||
|
|
||||||
|
class TwitchLinkFlow(MessageUI):
|
||||||
|
def __init__(self, bot: LionBot, caller: discord.User | discord.Member, *args, **kwargs):
|
||||||
|
kwargs.setdefault('callerid', caller.id)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
self._auth_task = None
|
||||||
|
self._stage: FlowState = FlowState.SETUP
|
||||||
|
self.flow = None
|
||||||
|
self.authrow = None
|
||||||
|
self.user = caller
|
||||||
|
|
||||||
|
self._info = None
|
||||||
|
self._migration_details = None
|
||||||
|
|
||||||
|
# ----- UI API -----
|
||||||
|
async def run(self, interaction: discord.Interaction, **kwargs):
|
||||||
|
await interaction.response.defer(ephemeral=True, thinking=True)
|
||||||
|
await self._start_flow()
|
||||||
|
await self.draw(interaction, **kwargs)
|
||||||
|
if self._stage is FlowState.ERRORED:
|
||||||
|
# This can happen if starting the flow failed
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
@log_wrap(action="start-twitch-flow-ui")
|
||||||
|
async def _start_flow(self):
|
||||||
|
logger.info(f"Starting twitch authentication flow for {self.user}")
|
||||||
|
try:
|
||||||
|
self.flow = await self.bot.get_cog('TwitchAuthCog').start_auth()
|
||||||
|
except aiohttp.ClientError:
|
||||||
|
self._stage = FlowState.ERRORED
|
||||||
|
self._info = (
|
||||||
|
"Could not establish a connection to the authentication server! "
|
||||||
|
"Please try again later~"
|
||||||
|
)
|
||||||
|
logger.exception("Unexpected exception while starting authentication flow!", exc_info=True)
|
||||||
|
else:
|
||||||
|
self._stage = FlowState.WAITING
|
||||||
|
self._auth_task = asyncio.create_task(self._auth_flow())
|
||||||
|
|
||||||
|
@log_wrap(action="run-twitch-flow-ui")
|
||||||
|
async def _auth_flow(self):
|
||||||
|
"""
|
||||||
|
Run the flow and wait for a timeout, cancellation, or callback.
|
||||||
|
Update the message accordingly.
|
||||||
|
"""
|
||||||
|
assert self.flow is not None
|
||||||
|
try:
|
||||||
|
# TODO: Cancel this in cleanup
|
||||||
|
authrow = await asyncio.wait_for(self.flow.run(), timeout=60)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._stage = FlowState.TIMEOUT
|
||||||
|
# Link Timed Out!
|
||||||
|
self._info = (
|
||||||
|
"We didn't receive a response so we closed the uplink "
|
||||||
|
"to keep your account safe! If you still want to connect, please try again!"
|
||||||
|
)
|
||||||
|
await self.refresh()
|
||||||
|
await self.close()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Presumably the user exited or the bot is shutting down.
|
||||||
|
# Not safe to edit the message, but try and cleanup
|
||||||
|
await self.close()
|
||||||
|
except SafeCancellation as e:
|
||||||
|
logger.info("User or server cancelled authentication flow: ", exc_info=True)
|
||||||
|
# Uplink Cancelled!
|
||||||
|
self._info = (
|
||||||
|
f"We couldn't complete the uplink!\nReason:*{e.msg}*"
|
||||||
|
)
|
||||||
|
self._stage = FlowState.CANCELLED
|
||||||
|
await self.refresh()
|
||||||
|
await self.close()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Something unexpected went wrong while running the flow!")
|
||||||
|
else:
|
||||||
|
self._stage = FlowState.WORKING
|
||||||
|
self._info = (
|
||||||
|
"Authentication complete! Connecting your Dreamspace account ...."
|
||||||
|
)
|
||||||
|
self.authrow = authrow
|
||||||
|
await self.refresh()
|
||||||
|
await self._link_twitch(str(authrow.userid))
|
||||||
|
await self.refresh()
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
await super().cleanup()
|
||||||
|
if self._auth_task and not self._auth_task.cancelled():
|
||||||
|
self._auth_task.cancel()
|
||||||
|
|
||||||
|
async def _link_twitch(self, twitch_id: str):
|
||||||
|
"""
|
||||||
|
Link the caller's profile to the given twitch_id.
|
||||||
|
|
||||||
|
Performs migration if needed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
twitch_user = await first(self.bot.twitch.get_users(user_ids=[twitch_id]))
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"Looking up user {self.authrow} from Twitch authentication flow raised an error."
|
||||||
|
)
|
||||||
|
self._stage = FlowState.ERRORED
|
||||||
|
self._info = "Failed to look up your user details from Twitch! Please try again later."
|
||||||
|
return
|
||||||
|
|
||||||
|
if twitch_user is None:
|
||||||
|
logger.error(
|
||||||
|
f"User {self.authrow} obtained from Twitch authentication does not exist."
|
||||||
|
)
|
||||||
|
self._stage = FlowState.ERRORED
|
||||||
|
self._info = "Authentication failed! Please try again later."
|
||||||
|
return
|
||||||
|
|
||||||
|
profiles = self.bot.get_cog('ProfileCog')
|
||||||
|
userid = self.user.id
|
||||||
|
|
||||||
|
caller_profile = await UserProfile.fetch_from_discordid(self.bot, userid)
|
||||||
|
twitch_profile = await UserProfile.fetch_from_twitchid(self.bot, twitch_id)
|
||||||
|
|
||||||
|
succ_info = (
|
||||||
|
f"Successfully established uplink to your Twitch account **{twitch_user.display_name}** "
|
||||||
|
"and transferred dreamspace data! Happy adventuring, and watch out for the grue~"
|
||||||
|
)
|
||||||
|
# ::TODO-MARKER::
|
||||||
|
|
||||||
|
if twitch_profile is None:
|
||||||
|
if caller_profile is None:
|
||||||
|
# Neither profile exists
|
||||||
|
profile = await UserProfile.create_from_discord(self.bot, self.user)
|
||||||
|
await profile.attach_twitch(twitch_id)
|
||||||
|
|
||||||
|
self._stage = FlowState.DONE
|
||||||
|
self._info = succ_info
|
||||||
|
else:
|
||||||
|
await caller_profile.attach_twitch(twitch_id)
|
||||||
|
|
||||||
|
self._stage = FlowState.DONE
|
||||||
|
self._info = succ_info
|
||||||
|
else:
|
||||||
|
if caller_profile is None:
|
||||||
|
await twitch_profile.attach_discord(self.user.id)
|
||||||
|
|
||||||
|
self._stage = FlowState.DONE
|
||||||
|
self._info = succ_info
|
||||||
|
elif twitch_profile.profileid == caller_profile.profileid:
|
||||||
|
self._stage = FlowState.CANCELLED
|
||||||
|
self._info = (
|
||||||
|
f"The Twitch account **{twitch_user.display_name}** is already linked to your profile!"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# In this case we have conflicting profiles we need to migrate
|
||||||
|
try:
|
||||||
|
results = await profiles.migrate_profile(twitch_profile, caller_profile)
|
||||||
|
except Exception:
|
||||||
|
self._stage = FlowState.ERRORED
|
||||||
|
self._info = (
|
||||||
|
"An issue was encountered while merging your account profiles! "
|
||||||
|
"The migration was rolled back, and not data has been lost.\n"
|
||||||
|
"The developer has been notified, please try again later!"
|
||||||
|
)
|
||||||
|
logger.exception(f"Failed to migrate profiles {twitch_profile=} to {caller_profile=}")
|
||||||
|
else:
|
||||||
|
self._stage = FlowState.DONE
|
||||||
|
self._info = succ_info
|
||||||
|
self._migration_details = '\n'.join(results)
|
||||||
|
logger.info(
|
||||||
|
f"Migrated {twitch_profile=} to {caller_profile}. Info: {self._migration_details}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ----- UI Flow -----
|
||||||
|
async def make_message(self) -> MessageArgs:
|
||||||
|
if self._stage is FlowState.SETUP:
|
||||||
|
raise ValueError("Making message before flow initialisation!")
|
||||||
|
assert self.flow is not None
|
||||||
|
|
||||||
|
if self._stage is FlowState.WAITING:
|
||||||
|
# Message should be the initial request page
|
||||||
|
dur = discord.utils.format_dt(utc_now() + timedelta(seconds=60), style='R')
|
||||||
|
|
||||||
|
title = "Press the button to login!"
|
||||||
|
desc = (
|
||||||
|
"We have generated a custom secure link for you to connect your Twitch profile! "
|
||||||
|
"Press the button below and accept the connection in your browser, "
|
||||||
|
"and we will begin the transfer!\n"
|
||||||
|
f"(Note: The link expires {dur})"
|
||||||
|
)
|
||||||
|
colour = brand.ACCENT_COLOUR
|
||||||
|
elif self._stage is FlowState.CANCELLED:
|
||||||
|
# Show cancellation message
|
||||||
|
# Show 'you can close this'
|
||||||
|
title = "Uplink Cancelled!"
|
||||||
|
desc = self._info
|
||||||
|
colour = discord.Colour.brand_red()
|
||||||
|
elif self._stage is FlowState.TIMEOUT:
|
||||||
|
title = "Link Timed Out"
|
||||||
|
desc = self._info
|
||||||
|
colour = discord.Colour.brand_red()
|
||||||
|
elif self._stage is FlowState.ERRORED:
|
||||||
|
title = "Something went wrong!"
|
||||||
|
desc = self._info
|
||||||
|
colour = discord.Colour.brand_red()
|
||||||
|
elif self._stage is FlowState.WORKING:
|
||||||
|
# We've received the auth, we are now doing migration
|
||||||
|
title = "Establishing Connection"
|
||||||
|
desc = self._info
|
||||||
|
colour = brand.ACCENT_COLOUR
|
||||||
|
elif self._stage is FlowState.DONE:
|
||||||
|
title = "Success!"
|
||||||
|
desc = self._info
|
||||||
|
colour = discord.Colour.brand_green()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid stage value {self._stage}")
|
||||||
|
|
||||||
|
embed = discord.Embed(title=title, description=desc, colour=colour, timestamp=utc_now())
|
||||||
|
if self._migration_details:
|
||||||
|
embed.add_field(
|
||||||
|
name="Profile migration details",
|
||||||
|
value=self._migration_details
|
||||||
|
)
|
||||||
|
return MessageArgs(embed=embed)
|
||||||
|
|
||||||
|
async def refresh_layout(self):
|
||||||
|
# If we haven't received the auth callback yet, make the flow link button
|
||||||
|
if self.flow is None:
|
||||||
|
raise ValueError("Refreshing before flow initialisation!")
|
||||||
|
|
||||||
|
if self._stage <= FlowState.WAITING:
|
||||||
|
flow_link = self.flow.auth.return_auth_url()
|
||||||
|
button = Button(
|
||||||
|
style=ButtonStyle.link,
|
||||||
|
url=flow_link,
|
||||||
|
label="Login With Twitch"
|
||||||
|
)
|
||||||
|
self.set_layout((button,))
|
||||||
|
else:
|
||||||
|
self.set_layout(())
|
||||||
|
|
||||||
|
async def reload(self):
|
||||||
|
pass
|
||||||
16
src/routes/__init__.py
Normal file
16
src/routes/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from .stamps import routes as stamp_routes
|
||||||
|
from .documents import routes as doc_routes
|
||||||
|
from .users import routes as user_routes
|
||||||
|
from .specimens import routes as spec_routes
|
||||||
|
from .transactions import routes as txn_routes
|
||||||
|
from .events import routes as event_routes
|
||||||
|
from .lib import dbvar, datamodelsv, profiledatav
|
||||||
|
|
||||||
|
|
||||||
|
def register_routes(router):
|
||||||
|
router.add_routes(stamp_routes)
|
||||||
|
router.add_routes(doc_routes)
|
||||||
|
router.add_routes(user_routes)
|
||||||
|
router.add_routes(spec_routes)
|
||||||
|
router.add_routes(event_routes)
|
||||||
|
router.add_routes(txn_routes)
|
||||||
370
src/routes/documents.py
Normal file
370
src/routes/documents.py
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
"""
|
||||||
|
- `/documents` with `POST, GET`
|
||||||
|
- `/documents/{document_id}` with `GET`, `PATCH`, `DELETE`
|
||||||
|
- `/documents/{document_id}/stamps` which is passed to `/stamps` with `document_id` set.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import binascii
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, NamedTuple, Optional, Self, TypedDict, Unpack, reveal_type, List
|
||||||
|
from aiohttp import web
|
||||||
|
import discord
|
||||||
|
from data import Condition, condition
|
||||||
|
from data.queries import JOINTYPE
|
||||||
|
from datamodels import DataModel
|
||||||
|
from utils.lib import MessageArgs, tabulate
|
||||||
|
|
||||||
|
from .lib import ModelField, datamodelsv, event_log
|
||||||
|
from .stamps import Stamp, StampCreateParams, StampEditParams, StampPayload
|
||||||
|
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DocPayload(TypedDict):
|
||||||
|
document_id: int
|
||||||
|
document_data: str
|
||||||
|
seal: int
|
||||||
|
created_at: str
|
||||||
|
metadata: Optional[str]
|
||||||
|
stamps: List[StampPayload]
|
||||||
|
|
||||||
|
|
||||||
|
class DocCreateParamsReq(TypedDict, total=True):
|
||||||
|
document_data: str
|
||||||
|
seal: int
|
||||||
|
|
||||||
|
|
||||||
|
class DocCreateParams(DocCreateParamsReq, total=False):
|
||||||
|
metadata: Optional[str]
|
||||||
|
stamps: List[StampCreateParams]
|
||||||
|
|
||||||
|
|
||||||
|
class DocEditParams(TypedDict, total=False):
|
||||||
|
document_data: str
|
||||||
|
seal: int
|
||||||
|
metadata: Optional[str]
|
||||||
|
stamps: List[StampCreateParams]
|
||||||
|
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
ModelField('document_id', int, False, False, False),
|
||||||
|
ModelField('document_data', str, True, True, True),
|
||||||
|
ModelField('seal', int, True, True, True),
|
||||||
|
ModelField('created_at', str, False, False, False),
|
||||||
|
ModelField('metadata', Optional[str], False, True, True),
|
||||||
|
ModelField('stamps', List[StampCreateParams], False, True, True),
|
||||||
|
]
|
||||||
|
req_fields = {field.name for field in fields if field.required}
|
||||||
|
edit_fields = {field.name for field in fields if field.can_edit}
|
||||||
|
create_fields = {field.name for field in fields if field.can_create}
|
||||||
|
|
||||||
|
|
||||||
|
class Document:
|
||||||
|
def __init__(self, app: web.Application, row: DataModel.Document):
|
||||||
|
self.app = app
|
||||||
|
self.data = app[datamodelsv]
|
||||||
|
self.row = row
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def validate_create_params(cls, params):
|
||||||
|
if extra := next((key for key in params if key not in create_fields), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to document creation.")
|
||||||
|
if missing := next((key for key in req_fields if key not in params), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"Document params missing required key '{missing}'")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_id(cls, app: web.Application, document_id: int) -> Optional[Self]:
|
||||||
|
data = app[datamodelsv]
|
||||||
|
row = await data.Document.fetch(document_id)
|
||||||
|
return cls(app, row) if row is not None else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def query(
|
||||||
|
cls,
|
||||||
|
app: web.Application,
|
||||||
|
document_id: Optional[int] = None,
|
||||||
|
seal: Optional[int] = None,
|
||||||
|
created_before: Optional[str] = None,
|
||||||
|
created_after: Optional[str] = None,
|
||||||
|
metadata: Optional[str] = None,
|
||||||
|
stamp_type: Optional[str] = None,
|
||||||
|
) -> List[Self]:
|
||||||
|
data = app[datamodelsv]
|
||||||
|
Doc = data.Document
|
||||||
|
|
||||||
|
conds = []
|
||||||
|
if document_id is not None:
|
||||||
|
conds.append(Doc.document_id == int(document_id))
|
||||||
|
if seal is not None:
|
||||||
|
conds.append(Doc.seal == int(seal))
|
||||||
|
if created_before is not None:
|
||||||
|
cbefore = datetime.fromisoformat(created_before)
|
||||||
|
conds.append(Doc.created_at <= cbefore)
|
||||||
|
if created_after is not None:
|
||||||
|
cafter = datetime.fromisoformat(created_after)
|
||||||
|
conds.append(Doc.created_at >= cafter)
|
||||||
|
if metadata is not None:
|
||||||
|
conds.append(Doc.metadata == metadata)
|
||||||
|
|
||||||
|
query = data.Document.table.fetch_rows_where(*conds)
|
||||||
|
# results = await query
|
||||||
|
|
||||||
|
# query = data.Document.table.select_where(*conds)
|
||||||
|
if stamp_type is not None:
|
||||||
|
query.join('document_stamps', using=('document_id',), join_type=JOINTYPE.LEFT)
|
||||||
|
query.join(
|
||||||
|
'stamp_types',
|
||||||
|
on=(data.DocumentStamp.stamp_type == data.StampType.stamp_type_id)
|
||||||
|
)
|
||||||
|
query.where(data.StampType.stamp_type_name == stamp_type)
|
||||||
|
query.select(docid = "DISTINCT(document_id)")
|
||||||
|
query.with_no_adapter()
|
||||||
|
results = await query
|
||||||
|
ids = [result['docid'] for result in results]
|
||||||
|
if ids:
|
||||||
|
rows = await data.Document.table.fetch_rows_where(
|
||||||
|
document_id=ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rows = []
|
||||||
|
else:
|
||||||
|
rows = await query
|
||||||
|
|
||||||
|
return [cls(app, row) for row in sorted(rows, key=lambda row:row.created_at)]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, app: web.Application, **kwargs: Unpack[DocCreateParams]) -> Self:
|
||||||
|
data = app[datamodelsv]
|
||||||
|
|
||||||
|
document_data = kwargs['document_data']
|
||||||
|
seal = kwargs['seal']
|
||||||
|
stamp_params = kwargs.get('stamps', [])
|
||||||
|
metadata = kwargs.get('metadata')
|
||||||
|
|
||||||
|
# Build the document first
|
||||||
|
row = await data.Document.create(
|
||||||
|
document_data=document_data,
|
||||||
|
seal=seal,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then build the stamps
|
||||||
|
stamps = []
|
||||||
|
for stampdata in stamp_params:
|
||||||
|
stampdata.setdefault('document_id', row.document_id)
|
||||||
|
stamp = await Stamp.create(app, **stampdata)
|
||||||
|
stamps.append(stamp)
|
||||||
|
|
||||||
|
self = cls(app, row)
|
||||||
|
# await self.log_create()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def get_stamps(self) -> List[Stamp]:
|
||||||
|
stamprows = await self.data.DocumentStamp.table.fetch_rows_where(document_id=self.row.document_id).order_by('stamp_id')
|
||||||
|
return [Stamp(self.app, row) for row in stamprows]
|
||||||
|
|
||||||
|
async def log_create(self, with_image=True):
|
||||||
|
args = await self.event_log_args(with_image=with_image)
|
||||||
|
args.kwargs['embed'].title = f"Document #{self.row.document_id} Created!"
|
||||||
|
try:
|
||||||
|
await event_log(**args.send_args)
|
||||||
|
except discord.HTTPException:
|
||||||
|
if with_image:
|
||||||
|
# Try again without the image in case that was the issue
|
||||||
|
await self.log_create(with_image=False)
|
||||||
|
|
||||||
|
async def log_edit(self, with_image=True):
|
||||||
|
args = await self.event_log_args(with_image=with_image)
|
||||||
|
args.kwargs['embed'].title = f"Document #{self.row.document_id} Updated!"
|
||||||
|
try:
|
||||||
|
await event_log(**args.send_args)
|
||||||
|
except discord.HTTPException:
|
||||||
|
if with_image:
|
||||||
|
# Try again without the image in case that was the issue
|
||||||
|
await self.log_create(with_image=False)
|
||||||
|
|
||||||
|
async def event_log_args(self, with_image=True) -> MessageArgs:
|
||||||
|
desc = '\n'.join(await self.tabulate())
|
||||||
|
embed = discord.Embed(description=desc, timestamp=self.row.created_at)
|
||||||
|
embed.set_footer(text='Created At')
|
||||||
|
args: dict = {'embed': embed}
|
||||||
|
|
||||||
|
if with_image:
|
||||||
|
try:
|
||||||
|
imagedata = self.row.to_bytes()
|
||||||
|
imagedata.seek(0)
|
||||||
|
embed.set_image(url='attachment://document.png')
|
||||||
|
args['files'] = [discord.File(imagedata, "document.png")]
|
||||||
|
except binascii.Error:
|
||||||
|
# Could not decode base64
|
||||||
|
embed.add_field(name='Image', value="Could not decode document data!")
|
||||||
|
else:
|
||||||
|
embed.add_field(name='Image', value="Failed to send image!")
|
||||||
|
return MessageArgs(**args)
|
||||||
|
|
||||||
|
async def tabulate(self):
|
||||||
|
"""
|
||||||
|
Present the Document as a discord-readable table.
|
||||||
|
"""
|
||||||
|
stamps = await self.get_stamps()
|
||||||
|
typnames = []
|
||||||
|
if stamps:
|
||||||
|
typs = {stamp.row.stamp_type for stamp in stamps}
|
||||||
|
for typ in typs:
|
||||||
|
# Stamp types should be cached so this isn't expensive
|
||||||
|
typrow = await self.data.StampType.fetch(typ)
|
||||||
|
typnames.append(typrow.stamp_type_name)
|
||||||
|
|
||||||
|
table = {
|
||||||
|
'document_id': f"`{self.row.document_id}`",
|
||||||
|
'seal': str(self.row.seal),
|
||||||
|
'metadata': f"`{self.row.metadata}`" if self.row.metadata else "No metadata",
|
||||||
|
'stamps': ', '.join(f"`{name}`" for name in typnames) if typnames else "No stamps",
|
||||||
|
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
|
||||||
|
}
|
||||||
|
return tabulate(*table.items())
|
||||||
|
|
||||||
|
|
||||||
|
async def prepare(self) -> DocPayload:
|
||||||
|
stamps = await self.get_stamps()
|
||||||
|
|
||||||
|
results: DocPayload = {
|
||||||
|
'document_id': self.row.document_id,
|
||||||
|
'document_data': self.row.document_data,
|
||||||
|
'seal': self.row.seal,
|
||||||
|
'created_at': self.row.created_at.isoformat(),
|
||||||
|
'metadata': self.row.metadata,
|
||||||
|
'stamps': [await stamp.prepare() for stamp in stamps]
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def edit(self, **kwargs: Unpack[DocEditParams]):
|
||||||
|
data = self.data
|
||||||
|
row = self.row
|
||||||
|
# Update the row data
|
||||||
|
# If stamps are given, delete the existing ones
|
||||||
|
# Then write in the new ones.
|
||||||
|
update_args = {}
|
||||||
|
for key in {'document_data', 'seal', 'metadata'}:
|
||||||
|
if key in kwargs:
|
||||||
|
update_args[key] = kwargs[key]
|
||||||
|
if update_args:
|
||||||
|
await self.row.update(**update_args)
|
||||||
|
|
||||||
|
# TODO: Should really be in a transaction
|
||||||
|
# Actually each handler should be in a transaction
|
||||||
|
if new_stamps := kwargs.get('stamps', []):
|
||||||
|
await self.data.DocumentStamp.table.delete_where(document_id=self.row.document_id)
|
||||||
|
for stampdata in new_stamps:
|
||||||
|
stampdata.setdefault('document_id', row.document_id)
|
||||||
|
await Stamp.create(self.app, **stampdata)
|
||||||
|
await self.log_edit()
|
||||||
|
|
||||||
|
async def delete(self) -> DocPayload:
|
||||||
|
payload = await self.prepare()
|
||||||
|
await self.row.delete()
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/documents')
|
||||||
|
@routes.view('/documents/')
|
||||||
|
class DocumentsView(web.View):
|
||||||
|
async def get(self):
|
||||||
|
request = self.request
|
||||||
|
filter_params = {}
|
||||||
|
keys = [
|
||||||
|
'document_id',
|
||||||
|
'seal',
|
||||||
|
'created_before',
|
||||||
|
'created_after',
|
||||||
|
'metadata',
|
||||||
|
'stamp_type',
|
||||||
|
]
|
||||||
|
for key in keys:
|
||||||
|
if key in request.query:
|
||||||
|
filter_params[key] = request.query[key]
|
||||||
|
elif key in request:
|
||||||
|
filter_params[key] = request[key]
|
||||||
|
|
||||||
|
documents = await Document.query(request.app, **filter_params)
|
||||||
|
payload = [await doc.prepare() for doc in documents]
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def post(self):
|
||||||
|
request = self.request
|
||||||
|
|
||||||
|
params = await request.json()
|
||||||
|
for key in create_fields:
|
||||||
|
if key in request:
|
||||||
|
params.setdefault(key, request[key])
|
||||||
|
|
||||||
|
await Document.validate_create_params(params)
|
||||||
|
|
||||||
|
document = await Document.create(self.request.app, **params)
|
||||||
|
payload = await document.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
@routes.view('/documents/{document_id}')
|
||||||
|
@routes.view('/documents/{document_id}/')
|
||||||
|
class DocumentView(web.View):
|
||||||
|
|
||||||
|
async def resolve_document(self):
|
||||||
|
request = self.request
|
||||||
|
document_id = request.match_info['document_id']
|
||||||
|
document = await Document.fetch_from_id(request.app, int(document_id))
|
||||||
|
if document is None:
|
||||||
|
raise web.HTTPNotFound(text="No document exists with the given ID.")
|
||||||
|
return document
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
doc = await self.resolve_document()
|
||||||
|
payload = await doc.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def patch(self):
|
||||||
|
doc = await self.resolve_document()
|
||||||
|
params = await self.request.json()
|
||||||
|
|
||||||
|
edit_data = {}
|
||||||
|
for key, value in params.items():
|
||||||
|
if key not in edit_fields:
|
||||||
|
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of Document!")
|
||||||
|
edit_data[key] = value
|
||||||
|
|
||||||
|
for key in edit_fields:
|
||||||
|
if key in self.request:
|
||||||
|
edit_data.setdefault(key, self.request[key])
|
||||||
|
|
||||||
|
await doc.edit(**edit_data)
|
||||||
|
payload = await doc.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
doc = await self.resolve_document()
|
||||||
|
payload = await doc.delete()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
|
||||||
|
# We have one prefix route, /documents/{document_id}/stamps
|
||||||
|
@routes.route('*', "/documents/{document_id}{tail:/stamps}")
|
||||||
|
@routes.route('*', "/documents/{document_id}{tail:/stamps/.*}")
|
||||||
|
async def document_stamps_route(request: web.Request):
|
||||||
|
document_id = int(request.match_info['document_id'])
|
||||||
|
document = await Document.fetch_from_id(request.app, document_id)
|
||||||
|
if document is None:
|
||||||
|
raise web.HTTPNotFound(text="No document exists with the given ID.")
|
||||||
|
|
||||||
|
new_path = request.match_info['tail']
|
||||||
|
logger.info(f"Redirecting {request=} to {new_path=} and setting {document_id=}")
|
||||||
|
new_request = request.clone(rel_url=new_path)
|
||||||
|
new_request['document_id'] = document_id
|
||||||
|
match_info = await request.app.router.resolve(new_request)
|
||||||
|
match_info.current_app = request.app
|
||||||
|
new_request._match_info = match_info
|
||||||
|
if match_info.handler:
|
||||||
|
return await match_info.handler(new_request)
|
||||||
|
else:
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
|
||||||
|
|
||||||
505
src/routes/events.py
Normal file
505
src/routes/events.py
Normal file
@@ -0,0 +1,505 @@
|
|||||||
|
import binascii
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List
|
||||||
|
from aiohttp import web
|
||||||
|
import discord
|
||||||
|
from data import Condition, condition
|
||||||
|
from data.conditions import NULL
|
||||||
|
from data.queries import JOINTYPE
|
||||||
|
from datamodels import DataModel, EventType
|
||||||
|
|
||||||
|
from modules.profiles.data import ProfileData
|
||||||
|
from utils.lib import MessageArgs, tabulate
|
||||||
|
|
||||||
|
from .lib import ModelField, datamodelsv, dbvar, event_log, profiledatav
|
||||||
|
from .specimens import Specimen, SpecimenPayload
|
||||||
|
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Event:
|
||||||
|
def __init__(self, app: web.Application, row: DataModel.EventDetails):
|
||||||
|
self.app = app
|
||||||
|
self.data = app[datamodelsv]
|
||||||
|
|
||||||
|
self.row = row
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_id(cls, app: web.Application, event_id: int):
|
||||||
|
data = app[datamodelsv]
|
||||||
|
row = await data.EventDetails.fetch(int(event_id))
|
||||||
|
return cls(app, row) if row is not None else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def query(
|
||||||
|
cls,
|
||||||
|
app: web.Application,
|
||||||
|
event_id: Optional[str] = None,
|
||||||
|
document_id: Optional[str] = None,
|
||||||
|
document_seal: Optional[str] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
user_name: Optional[str] = None,
|
||||||
|
occurred_before: Optional[str] = None,
|
||||||
|
occurred_after: Optional[str] = None,
|
||||||
|
created_before: Optional[str] = None,
|
||||||
|
created_after: Optional[str] = None,
|
||||||
|
event_type: Optional[str] = None,
|
||||||
|
) -> List[Self]:
|
||||||
|
data = app[datamodelsv]
|
||||||
|
EventD = data.EventDetails
|
||||||
|
|
||||||
|
conds = []
|
||||||
|
if event_id is not None:
|
||||||
|
conds.append(EventD.event_id == int(event_id))
|
||||||
|
if document_id is not None:
|
||||||
|
conds.append(EventD.document_id == int(document_id))
|
||||||
|
if document_seal is not None:
|
||||||
|
conds.append(EventD.document_seal == int(document_seal))
|
||||||
|
if user_id is not None:
|
||||||
|
conds.append(EventD.user_id == int(user_id))
|
||||||
|
if user_name is not None:
|
||||||
|
conds.append(EventD.user_name == user_name)
|
||||||
|
if created_before is not None:
|
||||||
|
cbefore = datetime.fromisoformat(created_before)
|
||||||
|
conds.append(EventD.created_at <= cbefore)
|
||||||
|
if created_after is not None:
|
||||||
|
cafter = datetime.fromisoformat(created_after)
|
||||||
|
conds.append(EventD.created_at >= cafter)
|
||||||
|
if occurred_before is not None:
|
||||||
|
before = datetime.fromisoformat(occurred_before)
|
||||||
|
conds.append(EventD.occurred_at <= before)
|
||||||
|
if occurred_after is not None:
|
||||||
|
after = datetime.fromisoformat(occurred_after)
|
||||||
|
conds.append(EventD.occurred_at >= after)
|
||||||
|
if event_type is not None:
|
||||||
|
ekey = (event_type.lower().strip(),)
|
||||||
|
if ekey not in [e.value for e in EventType]:
|
||||||
|
raise web.HTTPBadRequest(text=f"Unknown event type '{event_type}'")
|
||||||
|
conds.append(EventD.event_type == EventType(ekey))
|
||||||
|
|
||||||
|
rows = await EventD.fetch_where(*conds).order_by(EventD.occurred_at)
|
||||||
|
return [cls(app, row) for row in rows]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def validate_create_params(cls, params):
|
||||||
|
if 'event_type' not in params:
|
||||||
|
raise web.HTTPBadRequest(text="Event creation missing required field 'event_type'.")
|
||||||
|
|
||||||
|
ekey = (params['event_type'].lower().strip(),)
|
||||||
|
if ekey not in [e.value for e in EventType]:
|
||||||
|
raise web.HTTPBadRequest(text=f"Unknown event type '{params['event_type']}'")
|
||||||
|
event_type = EventType(ekey)
|
||||||
|
|
||||||
|
req_fields = {
|
||||||
|
'user_name', 'occurred_at', 'event_type',
|
||||||
|
}
|
||||||
|
other_fields = {
|
||||||
|
'document_id', 'document',
|
||||||
|
'user_id', 'user',
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'user_id' not in params and 'user' not in params:
|
||||||
|
raise web.HTTPBadRequest(text="One of 'user_id' or 'user' must be supplied to create Event.")
|
||||||
|
|
||||||
|
match event_type:
|
||||||
|
case EventType.PLAIN:
|
||||||
|
req_fields.add('message')
|
||||||
|
case EventType.SUBSCRIBER:
|
||||||
|
req_fields.add('tier')
|
||||||
|
req_fields.add('subscribed_length')
|
||||||
|
other_fields.add('message')
|
||||||
|
case EventType.CHEER:
|
||||||
|
req_fields.add('amount')
|
||||||
|
other_fields.add('cheer_type')
|
||||||
|
other_fields.add('message')
|
||||||
|
case EventType.RAID:
|
||||||
|
req_fields.add('viewer_count')
|
||||||
|
|
||||||
|
create_fields = req_fields.union(other_fields)
|
||||||
|
|
||||||
|
if extra := next((key for key in params if key not in create_fields), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to {event_type} event creation.")
|
||||||
|
if missing := next((key for key in req_fields if key not in params), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"{event_type} Event params missing required key '{missing}'")
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, app: web.Application, **kwargs):
|
||||||
|
data = app[datamodelsv]
|
||||||
|
# EventD = data.EventDetails
|
||||||
|
|
||||||
|
ekey = (kwargs['event_type'].lower().strip(),)
|
||||||
|
if ekey not in [e.value for e in EventType]:
|
||||||
|
raise web.HTTPBadRequest(text=f"Unknown event type '{kwargs['event_type']}'")
|
||||||
|
event_type = EventType(ekey)
|
||||||
|
|
||||||
|
params = {}
|
||||||
|
typparams = {}
|
||||||
|
|
||||||
|
match event_type:
|
||||||
|
case EventType.PLAIN:
|
||||||
|
typtab = data.plain_events
|
||||||
|
typparams['message'] = kwargs['message']
|
||||||
|
case EventType.CHEER:
|
||||||
|
typtab = data.cheer_events
|
||||||
|
typparams['amount'] = kwargs['amount']
|
||||||
|
typparams['cheer_type'] = kwargs.get('cheer_type')
|
||||||
|
typparams['message'] = kwargs.get('message')
|
||||||
|
case EventType.RAID:
|
||||||
|
typtab = data.raid_events
|
||||||
|
typparams['visitor_count'] = kwargs.get('viewer_count')
|
||||||
|
case EventType.SUBSCRIBER:
|
||||||
|
typtab = data.subscriber_events
|
||||||
|
typparams['tier'] = kwargs['tier']
|
||||||
|
typparams['subscribed_length'] = kwargs['subscribed_length']
|
||||||
|
typparams['message'] = kwargs.get('message')
|
||||||
|
case _:
|
||||||
|
raise ValueError("Invalid EventType")
|
||||||
|
|
||||||
|
# TODO: This really really should be a transaction
|
||||||
|
|
||||||
|
# Create Document if required
|
||||||
|
if 'document' in kwargs:
|
||||||
|
from .documents import Document
|
||||||
|
doc_args = kwargs['document']
|
||||||
|
await Document.validate_create_params(doc_args)
|
||||||
|
doc = await Document.create(app, **doc_args)
|
||||||
|
document_id = doc.row.document_id
|
||||||
|
params['document_id'] = document_id
|
||||||
|
elif 'document_id' in kwargs:
|
||||||
|
document_id = kwargs['document_id']
|
||||||
|
params['document_id'] = document_id
|
||||||
|
|
||||||
|
# Create User if required
|
||||||
|
if 'user' in kwargs:
|
||||||
|
from .users import User
|
||||||
|
user_args = kwargs['user']
|
||||||
|
await User.validate_create_params(user_args)
|
||||||
|
user = await User.create(app, **user_args)
|
||||||
|
user_id = user.row.user_id
|
||||||
|
|
||||||
|
if 'user_id' in kwargs and not kwargs['user_id'] == user_id:
|
||||||
|
raise web.HTTPBadRequest(text="Provided 'user_id' does not match provided 'user'.")
|
||||||
|
else:
|
||||||
|
user_id = kwargs['user_id']
|
||||||
|
params['user_id'] = user_id
|
||||||
|
|
||||||
|
# Create Event row
|
||||||
|
params['event_type'] = event_type
|
||||||
|
params['user_name'] = kwargs['user_name']
|
||||||
|
params['occurred_at'] = datetime.fromisoformat(kwargs['occurred_at'])
|
||||||
|
|
||||||
|
eventrow = await data.Events.create(**params)
|
||||||
|
typparams['event_id'] = eventrow.event_id
|
||||||
|
|
||||||
|
# Create Event type row
|
||||||
|
typrow = await typtab.insert(**typparams)
|
||||||
|
|
||||||
|
details = await data.EventDetails.fetch(eventrow.event_id)
|
||||||
|
assert details is not None
|
||||||
|
self = cls(app, details)
|
||||||
|
await self.log_create()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def log_create(self, with_image=True):
|
||||||
|
args = await self.event_log_args(with_image=with_image)
|
||||||
|
args.kwargs['embed'].title = f"Event #{self.row.event_id} Created!"
|
||||||
|
try:
|
||||||
|
await event_log(**args.send_args)
|
||||||
|
except discord.HTTPException:
|
||||||
|
if with_image:
|
||||||
|
# Try again without the image in case that was the issue
|
||||||
|
await self.log_create(with_image=False)
|
||||||
|
|
||||||
|
async def log_edit(self, with_image=True):
|
||||||
|
args = await self.event_log_args(with_image=with_image)
|
||||||
|
args.kwargs['embed'].title = f"Event #{self.row.event_id} Updated!"
|
||||||
|
try:
|
||||||
|
await event_log(**args.send_args)
|
||||||
|
except discord.HTTPException:
|
||||||
|
if with_image:
|
||||||
|
# Try again without the image in case that was the issue
|
||||||
|
await self.log_create(with_image=False)
|
||||||
|
|
||||||
|
async def event_log_args(self, with_image=True) -> MessageArgs:
|
||||||
|
desc = '\n'.join(await self.tabulate())
|
||||||
|
embed = discord.Embed(description=desc, timestamp=self.row.created_at)
|
||||||
|
embed.set_footer(text='Created At')
|
||||||
|
args: dict = {'embed': embed}
|
||||||
|
|
||||||
|
doc = await self.get_document()
|
||||||
|
if doc is not None:
|
||||||
|
embed.add_field(
|
||||||
|
name="Document",
|
||||||
|
value='\n'.join(await doc.tabulate()),
|
||||||
|
inline=False
|
||||||
|
)
|
||||||
|
if with_image:
|
||||||
|
try:
|
||||||
|
imagedata = doc.row.to_bytes()
|
||||||
|
imagedata.seek(0)
|
||||||
|
embed.set_image(url='attachment://document.png')
|
||||||
|
args['files'] = [discord.File(imagedata, "document.png")]
|
||||||
|
except binascii.Error:
|
||||||
|
# Could not decode base64
|
||||||
|
embed.add_field(name='Image', value="Could not decode document data!")
|
||||||
|
else:
|
||||||
|
embed.add_field(name='Image', value="Failed to send image!")
|
||||||
|
return MessageArgs(**args)
|
||||||
|
|
||||||
|
async def tabulate(self):
|
||||||
|
"""
|
||||||
|
Present the Event as a discord-readable table.
|
||||||
|
"""
|
||||||
|
user = await self.get_user()
|
||||||
|
assert user is not None
|
||||||
|
|
||||||
|
table = {
|
||||||
|
'event_id': f"`{self.row.event_id}`",
|
||||||
|
'event_type': f"`{self.row.event_type}`",
|
||||||
|
'user': f"`{self.row.user_id}` (`{self.row.user_name}`)",
|
||||||
|
'document': f"`{self.row.document_id}`",
|
||||||
|
'occurred_at': discord.utils.format_dt(self.row.occurred_at, 'F'),
|
||||||
|
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
|
||||||
|
}
|
||||||
|
info = self.row.event_type.info()
|
||||||
|
for col, param in zip(info.detailcolumns, info.params):
|
||||||
|
value = getattr(self.row, col)
|
||||||
|
table[param] = f"`{value}`"
|
||||||
|
return tabulate(*table.items())
|
||||||
|
|
||||||
|
async def edit(self, **kwargs):
|
||||||
|
data = self.data
|
||||||
|
# EventD = data.EventDetails
|
||||||
|
|
||||||
|
if 'event_type' in kwargs:
|
||||||
|
raise web.HTTPBadRequest(text="You cannot change the type of an event after creation.")
|
||||||
|
|
||||||
|
typparams = {}
|
||||||
|
|
||||||
|
match self.row.event_type:
|
||||||
|
case EventType.PLAIN:
|
||||||
|
typtab = data.plain_events
|
||||||
|
if 'message' in kwargs:
|
||||||
|
typparams['message'] = kwargs['message']
|
||||||
|
case EventType.CHEER:
|
||||||
|
typtab = data.cheer_events
|
||||||
|
for key in ('amount', 'cheer_type', 'message'):
|
||||||
|
if key in kwargs:
|
||||||
|
typparams[key] = kwargs[key]
|
||||||
|
case EventType.RAID:
|
||||||
|
typtab = data.raid_events
|
||||||
|
if 'viewer_count' in kwargs:
|
||||||
|
typparams['visitor_count'] = 'viewer_count'
|
||||||
|
case EventType.SUBSCRIBER:
|
||||||
|
typtab = data.subscriber_events
|
||||||
|
for key in ('tier', 'subscribed_length', 'message'):
|
||||||
|
if key in kwargs:
|
||||||
|
typparams[key] = kwargs[key]
|
||||||
|
if typparams:
|
||||||
|
await typtab.update_where(event_id=self.row.event_id).set(**typparams)
|
||||||
|
|
||||||
|
await self.log_edit()
|
||||||
|
await self.row.refresh()
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
payload = await self.prepare()
|
||||||
|
if self.row.document_id:
|
||||||
|
await self.data.Document.table.delete_where(document_id=self.row.document_id)
|
||||||
|
await self.data.Events.table.delete_where(event_id=self.row.event_id)
|
||||||
|
await self.row.refresh()
|
||||||
|
return payload
|
||||||
|
|
||||||
|
async def get_user(self):
|
||||||
|
from .users import User
|
||||||
|
return await User.fetch_from_id(self.app, self.row.user_id)
|
||||||
|
|
||||||
|
async def get_document(self):
|
||||||
|
from .documents import Document
|
||||||
|
if self.row.document_id:
|
||||||
|
return await Document.fetch_from_id(self.app, self.row.document_id)
|
||||||
|
|
||||||
|
async def prepare(self):
|
||||||
|
row = await self.row.refresh()
|
||||||
|
assert row is not None
|
||||||
|
data = self.data
|
||||||
|
|
||||||
|
user = await self.get_user()
|
||||||
|
assert user is not None
|
||||||
|
document = await self.get_document()
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
'event_id': self.row.event_id,
|
||||||
|
'document_id': self.row.document_id,
|
||||||
|
'document': await document.prepare() if document else None,
|
||||||
|
'user_id': self.row.user_id,
|
||||||
|
'user': await user.prepare(),
|
||||||
|
'user_name': self.row.user_name,
|
||||||
|
'occurred_at': self.row.occurred_at.isoformat(),
|
||||||
|
'created_at': self.row.created_at.isoformat(),
|
||||||
|
'event_type': self.row.event_type.value[0],
|
||||||
|
}
|
||||||
|
|
||||||
|
match row.event_type:
|
||||||
|
case EventType.PLAIN:
|
||||||
|
payload['message'] = row.plain_message
|
||||||
|
case EventType.SUBSCRIBER:
|
||||||
|
payload['tier'] = row.subscriber_tier
|
||||||
|
payload['subscribed_length'] = row.subscriber_length
|
||||||
|
payload['message'] = row.subscriber_message
|
||||||
|
case EventType.CHEER:
|
||||||
|
payload['amount'] = row.cheer_amount
|
||||||
|
payload['cheer_type'] = row.cheer_type
|
||||||
|
payload['message'] = row.cheer_message
|
||||||
|
case EventType.RAID:
|
||||||
|
payload['viewer_count'] = row.raid_visitor_count
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/events')
|
||||||
|
@routes.view('/events/', name='events')
|
||||||
|
class EventsView(web.View):
|
||||||
|
async def post(self):
|
||||||
|
request = self.request
|
||||||
|
|
||||||
|
params = await request.json()
|
||||||
|
if 'user_id' in request:
|
||||||
|
params.setdefault('user_id', request['user_id'])
|
||||||
|
|
||||||
|
await Event.validate_create_params(params)
|
||||||
|
logger.info(f"Creating a new event with args: {params=}")
|
||||||
|
event = await Event.create(self.request.app, **params)
|
||||||
|
logger.debug(f"Created event: {event!r}")
|
||||||
|
payload = await event.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
request = self.request
|
||||||
|
filter_params = {}
|
||||||
|
keys = [
|
||||||
|
'event_id', 'document_id', 'document_seal',
|
||||||
|
'user_id', 'user_name', 'occurred_before', 'occurred_after',
|
||||||
|
'created_before', 'created_after', 'event_type',
|
||||||
|
]
|
||||||
|
for key in keys:
|
||||||
|
value = request.query.get(key, request.get(key, None))
|
||||||
|
filter_params[key] = value
|
||||||
|
|
||||||
|
logger.info(f"Querying events with params: {filter_params=}")
|
||||||
|
events = await Event.query(request.app, **filter_params)
|
||||||
|
payload = [await event.prepare() for event in events]
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/events/{event_id}')
|
||||||
|
@routes.view('/events/{event_id}/', name='event')
|
||||||
|
class EventView(web.View):
|
||||||
|
async def resolve_event(self):
|
||||||
|
request = self.request
|
||||||
|
event_id = request.match_info['event_id']
|
||||||
|
event = await Event.fetch_from_id(request.app, int(event_id))
|
||||||
|
if event is None:
|
||||||
|
raise web.HTTPNotFound(text="No event exists with the given ID.")
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
event = await self.resolve_event()
|
||||||
|
logger.info(f"Received GET for event {event=}")
|
||||||
|
payload = await event.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def patch(self):
|
||||||
|
event = await self.resolve_event()
|
||||||
|
params = await self.request.json()
|
||||||
|
|
||||||
|
edit_data = {}
|
||||||
|
edit_fields = {'message', 'amount', 'cheer_type', 'viewer_count', 'tier', 'subscriber_length', 'message'}
|
||||||
|
for key, value in params.items():
|
||||||
|
if key not in edit_fields:
|
||||||
|
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of User!")
|
||||||
|
edit_data[key] = value
|
||||||
|
|
||||||
|
for key in edit_fields:
|
||||||
|
if key in self.request:
|
||||||
|
edit_data.setdefault(key, self.request[key])
|
||||||
|
|
||||||
|
logger.info(f"Received PATCH for event {event} with params: {params}")
|
||||||
|
await event.edit(**edit_data)
|
||||||
|
payload = await event.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
event = await self.resolve_event()
|
||||||
|
logger.info(f"Received DELETE for event {event}")
|
||||||
|
payload = await event.delete()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@routes.route('*', "/events/{event_id}/user")
|
||||||
|
@routes.route('*', "/events/{event_id}/user{tail:/.*}")
|
||||||
|
async def event_user_route(request: web.Request):
|
||||||
|
event_id = int(request.match_info['event_id'])
|
||||||
|
event = await Event.fetch_from_id(request.app, event_id)
|
||||||
|
if event is None:
|
||||||
|
raise web.HTTPNotFound(text="No event exists with the given ID.")
|
||||||
|
|
||||||
|
tail = request.match_info.get('tail', '')
|
||||||
|
new_path = "/users/{user_id}".format(user_id=event.row.user_id) + tail
|
||||||
|
|
||||||
|
logger.info(f"Redirecting {request=} to {new_path}")
|
||||||
|
new_request = request.clone(rel_url=new_path)
|
||||||
|
new_request['user_id'] = event.row.user_id
|
||||||
|
match_info = await request.app.router.resolve(new_request)
|
||||||
|
new_request._match_info = match_info
|
||||||
|
match_info.current_app = request.app
|
||||||
|
|
||||||
|
if match_info.handler:
|
||||||
|
return await match_info.handler(new_request)
|
||||||
|
else:
|
||||||
|
logger.info(f"Could not find handler matching {new_request}")
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
|
||||||
|
|
||||||
|
@routes.route('*', "/events/{event_id}/document")
|
||||||
|
@routes.route('*', "/events/{event_id}/document{tail:/.*}")
|
||||||
|
async def event_document_route(request: web.Request):
|
||||||
|
event_id = int(request.match_info['event_id'])
|
||||||
|
event = await Event.fetch_from_id(request.app, event_id)
|
||||||
|
if event is None:
|
||||||
|
raise web.HTTPNotFound(text="No event exists with the given ID.")
|
||||||
|
|
||||||
|
tail = request.match_info.get('tail', '')
|
||||||
|
|
||||||
|
document = await event.get_document()
|
||||||
|
if document is None:
|
||||||
|
if request.method == 'POST' and not tail:
|
||||||
|
new_path = '/documents'
|
||||||
|
logger.info(f"Redirecting {request=} to POST /documents")
|
||||||
|
new_request = request.clone(rel_url=new_path)
|
||||||
|
new_request['event_id'] = event_id
|
||||||
|
match_info = await request.app.router.resolve(new_request)
|
||||||
|
new_request._match_info = match_info
|
||||||
|
match_info.current_app = request.app
|
||||||
|
return await match_info.handler(new_request)
|
||||||
|
else:
|
||||||
|
raise web.HTTPNotFound(text="This event has no document.")
|
||||||
|
else:
|
||||||
|
document_id = document.row.document_id
|
||||||
|
# Redirect to POST /documents/{document_id}/...
|
||||||
|
new_path = f"/documents/{document_id}".format(document_id=document_id) + tail
|
||||||
|
logger.info(f"Redirecting {request=} to {new_path}")
|
||||||
|
new_request = request.clone(rel_url=new_path)
|
||||||
|
new_request['event_id'] = event_id
|
||||||
|
new_request['document_id'] = document_id
|
||||||
|
match_info = await request.app.router.resolve(new_request)
|
||||||
|
new_request._match_info = match_info
|
||||||
|
match_info.current_app = request.app
|
||||||
|
if match_info.handler:
|
||||||
|
return await match_info.handler(new_request)
|
||||||
|
else:
|
||||||
|
logger.info(f"Could not find handler matching {new_request}")
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
|
||||||
29
src/routes/lib.py
Normal file
29
src/routes/lib.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from typing import NamedTuple, Any, Optional, Self, Unpack, List, TypedDict
|
||||||
|
from aiohttp import web, ClientSession
|
||||||
|
from discord import Webhook
|
||||||
|
|
||||||
|
from data.database import Database
|
||||||
|
from datamodels import DataModel
|
||||||
|
from modules.profiles.data import ProfileData
|
||||||
|
from meta import conf
|
||||||
|
|
||||||
|
dbvar = web.AppKey("database", Database)
|
||||||
|
datamodelsv = web.AppKey("datamodels", DataModel)
|
||||||
|
profiledatav = web.AppKey("profiledata", ProfileData)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelField(NamedTuple):
|
||||||
|
name: str
|
||||||
|
typ: Any
|
||||||
|
required: bool
|
||||||
|
can_create: bool
|
||||||
|
can_edit: bool
|
||||||
|
|
||||||
|
|
||||||
|
async def event_log(*args, **kwargs):
|
||||||
|
# Post the given message to the configured event log, if set
|
||||||
|
event_log_url = conf.api.get('EVENTLOG')
|
||||||
|
if event_log_url:
|
||||||
|
async with ClientSession() as session:
|
||||||
|
webhook = Webhook.from_url(event_log_url, session=session)
|
||||||
|
await webhook.send(**kwargs)
|
||||||
304
src/routes/specimens.py
Normal file
304
src/routes/specimens.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
"""
|
||||||
|
- `/specimens` with `GET` and `POST`
|
||||||
|
- `/specimens/{specimen_id}` with `PATCH` and `DELETE`
|
||||||
|
- `/specimens/{specimen_id}/owner` which is passed to `/users/{user_id}`
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List, TYPE_CHECKING
|
||||||
|
from aiohttp import web
|
||||||
|
from data import Condition, condition
|
||||||
|
from data.conditions import NULL
|
||||||
|
from data.queries import JOINTYPE
|
||||||
|
from datamodels import DataModel
|
||||||
|
|
||||||
|
from .lib import ModelField, datamodelsv, dbvar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .users import UserCreateParams, UserPayload, User
|
||||||
|
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SpecimenPayload(TypedDict):
|
||||||
|
specimen_id: int
|
||||||
|
owner_id: int
|
||||||
|
owner: 'UserPayload'
|
||||||
|
born_at: str
|
||||||
|
forgotten_at: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class SpecimenCreateParamsReq(TypedDict, total=True):
|
||||||
|
owner_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class SpecimenCreateParams(SpecimenCreateParamsReq, total=False):
|
||||||
|
owner: 'UserCreateParams'
|
||||||
|
born_at: str
|
||||||
|
forgotten_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class SpecimenEditParams(TypedDict, total=False):
|
||||||
|
owner_id: int
|
||||||
|
forgotten_at: Optional[str]
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
ModelField('specimen_id', int, False, False, False),
|
||||||
|
ModelField('owner_id', int, False, True, True),
|
||||||
|
ModelField('owner', 'UserPayload', False, True, False),
|
||||||
|
ModelField('born_at', str, False, True, False),
|
||||||
|
ModelField('forgotten_at', str, False, True, True),
|
||||||
|
]
|
||||||
|
req_fields = {field.name for field in fields if field.required}
|
||||||
|
edit_fields = {field.name for field in fields if field.can_edit}
|
||||||
|
create_fields = {field.name for field in fields if field.can_create}
|
||||||
|
|
||||||
|
|
||||||
|
class Specimen:
|
||||||
|
def __init__(self, app: web.Application, row: DataModel.Specimen):
|
||||||
|
self.app = app
|
||||||
|
self.data = app[datamodelsv]
|
||||||
|
self.row = row
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def validate_create_params(cls, params):
|
||||||
|
if 'owner_id' not in params and 'owner' not in params:
|
||||||
|
raise web.HTTPBadRequest(text="One of 'owner' or 'owner_id' must be supplied to create Specimen.")
|
||||||
|
if extra := next((key for key in params if key not in create_fields), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to specimen creation.")
|
||||||
|
if missing := next((key for key in req_fields if key not in params), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"Specimen params missing required key '{missing}'")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_id(cls, app: web.Application, spec_id: int) -> Optional[Self]:
|
||||||
|
data = app[datamodelsv]
|
||||||
|
row = await data.Specimen.fetch(int(spec_id))
|
||||||
|
return cls(app, row) if row is not None else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def query(
|
||||||
|
cls,
|
||||||
|
app: web.Application,
|
||||||
|
specimen_id: Optional[str] = None,
|
||||||
|
owner_id: Optional[str] = None,
|
||||||
|
born_after: Optional[str] = None,
|
||||||
|
born_before: Optional[str] = None,
|
||||||
|
forgotten: Optional[str] = None,
|
||||||
|
forgotten_after: Optional[str] = None,
|
||||||
|
forgotten_before: Optional[str] = None,
|
||||||
|
) -> List[Self]:
|
||||||
|
data = app[datamodelsv]
|
||||||
|
Spec = data.Specimen
|
||||||
|
|
||||||
|
conds = []
|
||||||
|
|
||||||
|
if specimen_id is not None:
|
||||||
|
conds.append(Spec.specimen_id == int(specimen_id))
|
||||||
|
if owner_id is not None:
|
||||||
|
conds.append(Spec.owner_id == int(owner_id))
|
||||||
|
if born_after is not None:
|
||||||
|
bafter = datetime.fromisoformat(born_after)
|
||||||
|
conds.append(Spec.born_at >= bafter)
|
||||||
|
if born_before is not None:
|
||||||
|
bbefore = datetime.fromisoformat(born_before)
|
||||||
|
conds.append(Spec.born_at <= bbefore)
|
||||||
|
if forgotten_after is not None:
|
||||||
|
fafter = datetime.fromisoformat(forgotten_after)
|
||||||
|
conds.append(Spec.forgotten_at >= fafter)
|
||||||
|
if forgotten_before is not None:
|
||||||
|
fbefore = datetime.fromisoformat(forgotten_before)
|
||||||
|
conds.append(Spec.forgotten_at <= fbefore)
|
||||||
|
if forgotten is not None:
|
||||||
|
if forgotten.lower() in ('1', 'true'):
|
||||||
|
conds.append(Spec.forgotten_at != NULL)
|
||||||
|
elif forgotten.lower() in ('0', 'false'):
|
||||||
|
conds.append(Spec.forgotten_at == NULL)
|
||||||
|
rows = await Spec.fetch_where(*conds).order_by(Spec.born_at)
|
||||||
|
return [cls(app, row) for row in rows]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, app: web.Application, **kwargs: Unpack[SpecimenCreateParams]) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new specimen from the given data.
|
||||||
|
|
||||||
|
This will create the provided 'owner' if required.
|
||||||
|
"""
|
||||||
|
from .users import User
|
||||||
|
|
||||||
|
create_args = {}
|
||||||
|
|
||||||
|
if 'owner' in kwargs:
|
||||||
|
# Create owner and set owner_id
|
||||||
|
owner_args = kwargs['owner']
|
||||||
|
await User.validate_create_params(owner_args)
|
||||||
|
owner = await User.create(app, **owner_args)
|
||||||
|
owner_id = owner.row.user_id
|
||||||
|
|
||||||
|
if 'owner_id' in kwargs and not kwargs['owner_id'] == owner_id:
|
||||||
|
raise web.HTTPBadRequest(text="Provided `owner_id` does not match provided `owner`.")
|
||||||
|
else:
|
||||||
|
owner_id = int(kwargs['owner_id'])
|
||||||
|
create_args['owner_id'] = owner_id
|
||||||
|
|
||||||
|
if 'born_at' in kwargs:
|
||||||
|
create_args['born_at'] = datetime.fromisoformat(kwargs['born_at'])
|
||||||
|
if 'forgotten_at' in kwargs:
|
||||||
|
create_args['forgotten_at'] = datetime.fromisoformat(kwargs['forgotten_at'])
|
||||||
|
|
||||||
|
data = app[datamodelsv]
|
||||||
|
|
||||||
|
logger.info(f"Creating Specimen with {create_args=}")
|
||||||
|
|
||||||
|
row = await data.Specimen.create(**create_args)
|
||||||
|
return cls(app, row)
|
||||||
|
|
||||||
|
async def edit(self, **kwargs: Unpack[SpecimenEditParams]):
|
||||||
|
row = self.row
|
||||||
|
|
||||||
|
edit_args = {}
|
||||||
|
if 'owner_id' in kwargs:
|
||||||
|
edit_args['owner_id'] = kwargs['owner_id']
|
||||||
|
# TODO: We should probably check that the specified owner exists
|
||||||
|
if 'forgotten_at' in kwargs:
|
||||||
|
forg = kwargs['forgotten_at']
|
||||||
|
if forg is None:
|
||||||
|
# Allows unsetting the forgotten date
|
||||||
|
# This may error if the user already had a live specimen
|
||||||
|
edit_args['forgotten_at'] = None
|
||||||
|
else:
|
||||||
|
edit_args['forgotten_at'] = datetime.fromisoformat(forg)
|
||||||
|
|
||||||
|
if edit_args:
|
||||||
|
logger.info(f"Updating specimen {row=} with {kwargs}")
|
||||||
|
await row.update(**edit_args)
|
||||||
|
|
||||||
|
async def delete(self) -> SpecimenPayload:
|
||||||
|
payload = await self.prepare()
|
||||||
|
await self.row.delete()
|
||||||
|
return payload
|
||||||
|
|
||||||
|
async def get_owner(self):
|
||||||
|
from .users import User
|
||||||
|
return await User.fetch_from_id(self.app, self.row.owner_id)
|
||||||
|
|
||||||
|
async def prepare(self) -> SpecimenPayload:
|
||||||
|
owner = await self.get_owner()
|
||||||
|
if owner is None:
|
||||||
|
raise ValueError("Specimen Owner does not exist! This should never happen!")
|
||||||
|
|
||||||
|
results: SpecimenPayload = {
|
||||||
|
'specimen_id': self.row.specimen_id,
|
||||||
|
'owner_id': self.row.owner_id,
|
||||||
|
'owner': await owner.prepare(),
|
||||||
|
'born_at': self.row.born_at.isoformat(),
|
||||||
|
'forgotten_at': self.row.forgotten_at.isoformat() if self.row.forgotten_at else None
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/specimens')
|
||||||
|
@routes.view('/specimens/', name='specimens')
|
||||||
|
class SpecimensView(web.View):
|
||||||
|
async def post(self):
|
||||||
|
request = self.request
|
||||||
|
|
||||||
|
params = await request.json()
|
||||||
|
for key in create_fields:
|
||||||
|
if key in request:
|
||||||
|
params.setdefault(key, request[key])
|
||||||
|
|
||||||
|
await Specimen.validate_create_params(params)
|
||||||
|
logger.info(f"Creating a new Specimen with args: {params=}")
|
||||||
|
spec = await Specimen.create(self.request.app, **params)
|
||||||
|
logger.debug(f"Created specimen: {spec!r}")
|
||||||
|
|
||||||
|
payload = await spec.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
request = self.request
|
||||||
|
filter_params = {}
|
||||||
|
keys = [
|
||||||
|
'specimen_id', 'owner_id', 'born_after', 'born_before',
|
||||||
|
'forgotten', 'forgotten_after', 'forgotten_before'
|
||||||
|
]
|
||||||
|
for key in keys:
|
||||||
|
value = request.query.get(key, request.get(key, None))
|
||||||
|
filter_params[key] = value
|
||||||
|
|
||||||
|
logger.info(f"Querying specimens with params: {filter_params=}")
|
||||||
|
specs = await Specimen.query(request.app, **filter_params)
|
||||||
|
payload = [await spec.prepare() for spec in specs]
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/specimens/{specimen_id}')
|
||||||
|
@routes.view('/specimens/{specimen_id}/', name='specimen')
|
||||||
|
class SpecimenView(web.View):
|
||||||
|
async def resolve_specimen(self):
|
||||||
|
request = self.request
|
||||||
|
spec_id = request.match_info['specimen_id']
|
||||||
|
spec = await Specimen.fetch_from_id(request.app, int(spec_id))
|
||||||
|
|
||||||
|
if spec is None:
|
||||||
|
raise web.HTTPNotFound(text="No specimen exists with the given ID.")
|
||||||
|
|
||||||
|
return spec
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
spec = await self.resolve_specimen()
|
||||||
|
logger.info(f"Received GET for specimen {spec=}")
|
||||||
|
payload = await spec.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def patch(self):
|
||||||
|
spec = await self.resolve_specimen()
|
||||||
|
params = await self.request.json()
|
||||||
|
|
||||||
|
edit_data = {}
|
||||||
|
for key, value in params.items():
|
||||||
|
if key not in edit_fields:
|
||||||
|
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of Specimen!")
|
||||||
|
edit_data[key] = value
|
||||||
|
|
||||||
|
for key in edit_fields:
|
||||||
|
if key in self.request:
|
||||||
|
edit_data.setdefault(key, self.request[key])
|
||||||
|
|
||||||
|
logger.info(f"Received PATCH for specimen {spec} with params: {params}")
|
||||||
|
await spec.edit(**edit_data)
|
||||||
|
payload = await spec.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
spec = await self.resolve_specimen()
|
||||||
|
logger.info(f"Received DELETE for specimen {spec}")
|
||||||
|
payload = await spec.delete()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@routes.route('*', "/specimens/{specimen_id}/owner")
|
||||||
|
@routes.route('*', "/specimens/{specimen_id}/owner{tail:/.*}")
|
||||||
|
async def specimen_owner_route(request: web.Request):
|
||||||
|
spec_id = int(request.match_info['specimen_id'])
|
||||||
|
spec = await Specimen.fetch_from_id(request.app, spec_id)
|
||||||
|
if spec is None:
|
||||||
|
raise web.HTTPNotFound(text="No specimen exists with the given ID.")
|
||||||
|
|
||||||
|
tail = request.match_info.get('tail', '')
|
||||||
|
new_path = "/users/{user_id}".format(user_id=spec.row.owner_id) + tail
|
||||||
|
|
||||||
|
logger.info(f"Redirecting {request=} to {new_path}")
|
||||||
|
new_request = request.clone(rel_url=new_path)
|
||||||
|
new_request['user_id'] = spec.row.owner_id
|
||||||
|
match_info = await request.app.router.resolve(new_request)
|
||||||
|
new_request._match_info = match_info
|
||||||
|
match_info.current_app = request.app
|
||||||
|
|
||||||
|
if match_info.handler:
|
||||||
|
return await match_info.handler(new_request)
|
||||||
|
else:
|
||||||
|
logger.info(f"Could not find handler matching {new_request}")
|
||||||
|
raise web.HTTPNotFound()
|
||||||
262
src/routes/stamps.py
Normal file
262
src/routes/stamps.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
from typing import Any, NamedTuple, Optional, TypedDict, Unpack, reveal_type
|
||||||
|
from aiohttp import web
|
||||||
|
from data.database import Database
|
||||||
|
from datamodels import DataModel
|
||||||
|
|
||||||
|
from .lib import datamodelsv, ModelField
|
||||||
|
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
|
||||||
|
|
||||||
|
class StampPayload(TypedDict):
|
||||||
|
stamp_id: int
|
||||||
|
document_id: int
|
||||||
|
stamp_type: str
|
||||||
|
pos_x: int
|
||||||
|
pos_y: int
|
||||||
|
rotation: float
|
||||||
|
|
||||||
|
|
||||||
|
class StampCreateParamsReq(TypedDict, total=True):
|
||||||
|
document_id: int
|
||||||
|
stamp_type: str
|
||||||
|
pos_x: int
|
||||||
|
pos_y: int
|
||||||
|
rotation: float
|
||||||
|
|
||||||
|
|
||||||
|
class StampCreateParams(StampCreateParamsReq, total=False):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StampEditParams(TypedDict, total=False):
|
||||||
|
document_id: int
|
||||||
|
stamp_type: str
|
||||||
|
pos_x: int
|
||||||
|
pos_y: int
|
||||||
|
rotation: float
|
||||||
|
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
ModelField('stamp_id', int, False, False, False),
|
||||||
|
ModelField('document_id', int, True, True, True),
|
||||||
|
ModelField('stamp_type', int, True, True, True),
|
||||||
|
ModelField('pos_x', int, True, True, True),
|
||||||
|
ModelField('pos_y', int, True, True, True),
|
||||||
|
ModelField('rotation', float, True, True, True),
|
||||||
|
]
|
||||||
|
req_fields = {field.name for field in fields if field.required}
|
||||||
|
edit_fields = {field.name for field in fields if field.can_edit}
|
||||||
|
create_fields = {field.name for field in fields if field.can_create}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Stamp:
|
||||||
|
def __init__(self, app: web.Application, row: DataModel.DocumentStamp):
|
||||||
|
self.app = app
|
||||||
|
self.data = app[datamodelsv]
|
||||||
|
self.row = row
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_id(cls, app: web.Application, stamp_id: int) -> Optional['Stamp']:
|
||||||
|
stamp = await app[datamodelsv].DocumentStamp.fetch(stamp_id)
|
||||||
|
if stamp is None:
|
||||||
|
return None
|
||||||
|
return cls(app, stamp)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def query(
|
||||||
|
cls,
|
||||||
|
app: web.Application,
|
||||||
|
stamp_id: Optional[int] = None,
|
||||||
|
document_id: Optional[int] = None,
|
||||||
|
stamp_type: Optional[str] = None,
|
||||||
|
):
|
||||||
|
data = app[datamodelsv]
|
||||||
|
|
||||||
|
query_args = {}
|
||||||
|
if stamp_id is not None:
|
||||||
|
query_args['stamp_id'] = int(stamp_id)
|
||||||
|
if document_id is not None:
|
||||||
|
query_args['document_id'] = int(document_id)
|
||||||
|
if stamp_type is not None:
|
||||||
|
typerows = await data.StampType.table.fetch_rows_where(stamp_type_name=stamp_type)
|
||||||
|
typeids = [row.stamp_type_id for row in typerows]
|
||||||
|
if not typeids:
|
||||||
|
return []
|
||||||
|
query_args['stamp_type'] = typeids
|
||||||
|
results = await data.DocumentStamp.table.fetch_rows_where(**query_args)
|
||||||
|
return [cls(app, row) for row in sorted(results, key=lambda row:row.stamp_id)]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(
|
||||||
|
cls,
|
||||||
|
app: web.Application,
|
||||||
|
**kwargs: Unpack[StampCreateParams]
|
||||||
|
):
|
||||||
|
data = app[datamodelsv]
|
||||||
|
stamp_type = kwargs['stamp_type']
|
||||||
|
# Get the stamp_type
|
||||||
|
rows = await data.StampType.table.fetch_rows_where(stamp_type_name=stamp_type)
|
||||||
|
if not rows:
|
||||||
|
# Create the stamp type
|
||||||
|
row = await data.StampType.create(stamp_type_name=stamp_type)
|
||||||
|
else:
|
||||||
|
row = rows[0]
|
||||||
|
|
||||||
|
stamprow = await data.DocumentStamp.create(
|
||||||
|
document_id=kwargs['document_id'],
|
||||||
|
stamp_type=row.stamp_type_id,
|
||||||
|
position_x=int(kwargs['pos_x']),
|
||||||
|
position_y=int(kwargs['pos_y']),
|
||||||
|
rotation=float(kwargs['rotation'])
|
||||||
|
)
|
||||||
|
return cls(app, stamprow)
|
||||||
|
|
||||||
|
async def prepare(self) -> StampPayload:
|
||||||
|
typerow = await self.data.StampType.fetch(self.row.stamp_type)
|
||||||
|
assert typerow is not None
|
||||||
|
|
||||||
|
results: StampPayload = {
|
||||||
|
'stamp_id': self.row.stamp_id,
|
||||||
|
'document_id': self.row.document_id,
|
||||||
|
'stamp_type': typerow.stamp_type_name,
|
||||||
|
'pos_x': self.row.position_x,
|
||||||
|
'pos_y': self.row.position_y,
|
||||||
|
'rotation': self.row.rotation,
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def edit(
|
||||||
|
self,
|
||||||
|
**kwargs: Unpack[StampEditParams]
|
||||||
|
):
|
||||||
|
data = self.data
|
||||||
|
row = self.row
|
||||||
|
edit_args = {}
|
||||||
|
if stamp_type := kwargs.get('stamp_type'):
|
||||||
|
# Get the stamp_type
|
||||||
|
rows = await data.StampType.table.fetch_rows_where(stamp_type_name=stamp_type)
|
||||||
|
if not rows:
|
||||||
|
# Create the stamp type
|
||||||
|
row = await data.StampType.create(stamp_type_name=stamp_type)
|
||||||
|
else:
|
||||||
|
row = rows[0]
|
||||||
|
edit_args['stamp_type'] = row.stamp_type_id
|
||||||
|
simple_keys = {
|
||||||
|
'document_id': 'document_id',
|
||||||
|
'pos_x': 'position_x',
|
||||||
|
'pos_y': 'position_y',
|
||||||
|
'rotation': 'rotation'
|
||||||
|
}
|
||||||
|
for editkey, datakey in simple_keys.items():
|
||||||
|
if editkey in kwargs:
|
||||||
|
edit_args[datakey] = kwargs[editkey]
|
||||||
|
|
||||||
|
await self.row.update(
|
||||||
|
**edit_args
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete(self) -> StampPayload:
|
||||||
|
payload = await self.prepare()
|
||||||
|
await self.row.delete()
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/stamps')
|
||||||
|
@routes.view('/stamps/', name='stamps')
|
||||||
|
class StampsView(web.View):
|
||||||
|
async def get(self):
|
||||||
|
request = self.request
|
||||||
|
# Decode request parameters to filter args
|
||||||
|
filter_params = {}
|
||||||
|
|
||||||
|
keys = ['stamp_id', 'document_id', 'stamp_type']
|
||||||
|
for key in keys:
|
||||||
|
if key in request.query:
|
||||||
|
filter_params[key] = request.query[key]
|
||||||
|
elif key in request:
|
||||||
|
filter_params[key] = request[key]
|
||||||
|
|
||||||
|
stamps = await Stamp.query(request.app, **filter_params)
|
||||||
|
payload = [await stamp.prepare() for stamp in stamps]
|
||||||
|
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def create_one(self, params: StampCreateParams):
|
||||||
|
if extra := next((key for key in params if key not in create_fields), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to stamp creation.")
|
||||||
|
if missing := next((key for key in req_fields if key not in params), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"Stamp params missing required key '{missing}'.")
|
||||||
|
|
||||||
|
# This still doesn't guarantee that the values are of the correct type, but good enough.
|
||||||
|
stamp = await Stamp.create(self.request.app, **params)
|
||||||
|
return stamp
|
||||||
|
|
||||||
|
async def post(self):
|
||||||
|
request = self.request
|
||||||
|
|
||||||
|
params = await request.json()
|
||||||
|
for key in create_fields:
|
||||||
|
if key in request:
|
||||||
|
params.setdefault(key, request[key])
|
||||||
|
|
||||||
|
stamp = await self.create_one(params)
|
||||||
|
payload = await stamp.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def put(self):
|
||||||
|
request = self.request
|
||||||
|
|
||||||
|
from_request = {key: request[key] for key in create_fields if key in request}
|
||||||
|
argslist = await request.json()
|
||||||
|
|
||||||
|
payloads = []
|
||||||
|
for args in argslist:
|
||||||
|
stamp = await self.create_one(from_request | args)
|
||||||
|
payload = await stamp.prepare()
|
||||||
|
payloads.append(payload)
|
||||||
|
|
||||||
|
return web.json_response(payloads)
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/stamps/{stamp_id}')
|
||||||
|
@routes.view('/stamps/{stamp_id}/', name='stamp')
|
||||||
|
class StampView(web.View):
|
||||||
|
|
||||||
|
async def resolve_stamp(self):
|
||||||
|
request = self.request
|
||||||
|
stamp_id = request.match_info['stamp_id']
|
||||||
|
stamp = await Stamp.fetch_from_id(request.app, int(stamp_id))
|
||||||
|
if stamp is None:
|
||||||
|
raise web.HTTPNotFound(text="No stamp exists with the given ID.")
|
||||||
|
return stamp
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
stamp = await self.resolve_stamp()
|
||||||
|
payload = await stamp.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def patch(self):
|
||||||
|
stamp = await self.resolve_stamp()
|
||||||
|
params = await self.request.json()
|
||||||
|
|
||||||
|
edit_data = {}
|
||||||
|
for key, value in params.items():
|
||||||
|
if key not in edit_fields:
|
||||||
|
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of Stamp!")
|
||||||
|
edit_data[key] = value
|
||||||
|
|
||||||
|
for key in edit_fields:
|
||||||
|
if key in self.request:
|
||||||
|
edit_data.setdefault(key, self.request[key])
|
||||||
|
|
||||||
|
await stamp.edit(**edit_data)
|
||||||
|
payload = await stamp.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
stamp = await self.resolve_stamp()
|
||||||
|
payload = await stamp.delete()
|
||||||
|
return web.json_response(payload)
|
||||||
223
src/routes/transactions.py
Normal file
223
src/routes/transactions.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List, TYPE_CHECKING
|
||||||
|
from aiohttp import web
|
||||||
|
from data import Condition, condition
|
||||||
|
from data.conditions import NULL
|
||||||
|
from data.queries import JOINTYPE
|
||||||
|
from datamodels import DataModel
|
||||||
|
|
||||||
|
from .lib import ModelField, datamodelsv, dbvar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .users import UserCreateParams, UserPayload, User
|
||||||
|
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionPayload(TypedDict):
|
||||||
|
transaction_id: int
|
||||||
|
user_id: int
|
||||||
|
user: 'UserPayload'
|
||||||
|
amount: int
|
||||||
|
description: str
|
||||||
|
reference: Optional[str]
|
||||||
|
created_at: str
|
||||||
|
|
||||||
|
class TransactionCreateParamsReq(TypedDict, total=True):
|
||||||
|
user_id: int
|
||||||
|
amount: int
|
||||||
|
description: str
|
||||||
|
|
||||||
|
class TransactionCreateParams(TransactionCreateParamsReq, total=False):
|
||||||
|
reference: str
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
ModelField('transaction_id', int, False, False, False),
|
||||||
|
ModelField('user_id', int, True, True, False),
|
||||||
|
ModelField('amount', int, True, True, False),
|
||||||
|
ModelField('description', str, True, True, False),
|
||||||
|
ModelField('reference', str, False, True, False),
|
||||||
|
ModelField('created_at', str, False, False, False),
|
||||||
|
]
|
||||||
|
req_fields = {field.name for field in fields if field.required}
|
||||||
|
edit_fields = {field.name for field in fields if field.can_edit}
|
||||||
|
create_fields = {field.name for field in fields if field.can_create}
|
||||||
|
|
||||||
|
|
||||||
|
class Transaction:
|
||||||
|
def __init__(self, app: web.Application, row: DataModel.Transaction):
|
||||||
|
self.app = app
|
||||||
|
self.data = app[datamodelsv]
|
||||||
|
self.row = row
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def validate_create_params(cls, params):
|
||||||
|
if extra := next((key for key in params if key not in create_fields), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to transaction creation.")
|
||||||
|
if missing := next((key for key in req_fields if key not in params), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"Transaction params missing required key '{missing}'")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_id(cls, app: web.Application, tid: int) -> Optional[Self]:
|
||||||
|
data = app[datamodelsv]
|
||||||
|
row = await data.Transaction.fetch(int(tid))
|
||||||
|
return cls(app, row) if row is not None else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def query(
|
||||||
|
cls,
|
||||||
|
app: web.Application,
|
||||||
|
transaction_id: Optional[int] = None,
|
||||||
|
user_id: Optional[int] = None,
|
||||||
|
reference: Optional[str] = None,
|
||||||
|
created_before: Optional[str] = None,
|
||||||
|
created_after: Optional[str] = None,
|
||||||
|
) -> List[Self]:
|
||||||
|
data = app[datamodelsv]
|
||||||
|
TXN = data.Transaction
|
||||||
|
|
||||||
|
conds = []
|
||||||
|
|
||||||
|
if transaction_id is not None:
|
||||||
|
conds.append(TXN.transaction_id == int(transaction_id))
|
||||||
|
if user_id is not None:
|
||||||
|
conds.append(TXN.user_id == int(user_id))
|
||||||
|
if reference is not None:
|
||||||
|
conds.append(TXN.reference == reference)
|
||||||
|
if created_before is not None:
|
||||||
|
cbefore = datetime.fromisoformat(created_before)
|
||||||
|
conds.append(TXN.created_at <= cbefore)
|
||||||
|
if created_after is not None:
|
||||||
|
cafter = datetime.fromisoformat(created_after)
|
||||||
|
conds.append(TXN.created_at >= cafter)
|
||||||
|
|
||||||
|
rows = await TXN.fetch_where(*conds).order_by(TXN.created_at)
|
||||||
|
return [cls(app, row) for row in rows]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, app: web.Application, **kwargs: Unpack[TransactionCreateParams]) -> Self:
|
||||||
|
data = app[datamodelsv]
|
||||||
|
|
||||||
|
create_args = {}
|
||||||
|
|
||||||
|
for key in ('user_id', 'description', 'amount', 'reference'):
|
||||||
|
create_args[key] = kwargs.get(key)
|
||||||
|
|
||||||
|
logger.info(f"Creating Transaction with {create_args=}")
|
||||||
|
row = await data.Transaction.create(**create_args)
|
||||||
|
return cls(app, row)
|
||||||
|
|
||||||
|
async def edit(self):
|
||||||
|
raise ValueError("Transactions are immutable.")
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
raise ValueError("Transactions cannot be deleted directly.")
|
||||||
|
|
||||||
|
async def get_user(self):
|
||||||
|
from .users import User
|
||||||
|
return await User.fetch_from_id(self.app, self.row.user_id)
|
||||||
|
|
||||||
|
async def prepare(self) -> TransactionPayload:
|
||||||
|
user = await self.get_user()
|
||||||
|
if user is None:
|
||||||
|
raise ValueError("Transaction owner does not exist! This cannot happen.")
|
||||||
|
|
||||||
|
results: TransactionPayload = {
|
||||||
|
'transaction_id': self.row.transaction_id,
|
||||||
|
'user_id': self.row.user_id,
|
||||||
|
'user': await user.prepare(details=False),
|
||||||
|
'amount': self.row.amount,
|
||||||
|
'description': self.row.description,
|
||||||
|
'reference': self.row.reference,
|
||||||
|
'created_at': self.row.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/transactions')
|
||||||
|
@routes.view('/transactions/', name='transactions')
|
||||||
|
class TransactionsView(web.View):
|
||||||
|
async def post(self):
|
||||||
|
request = self.request
|
||||||
|
|
||||||
|
params = await request.json()
|
||||||
|
for key in create_fields:
|
||||||
|
if key in request:
|
||||||
|
params.setdefault(key, request[key])
|
||||||
|
|
||||||
|
await Transaction.validate_create_params(params)
|
||||||
|
logger.info(f"Creating a new Transaction with args: {params=}")
|
||||||
|
txn = await Transaction.create(self.request.app, **params)
|
||||||
|
logger.debug(f"Created transaction: {txn!r}")
|
||||||
|
|
||||||
|
payload = await txn.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
request = self.request
|
||||||
|
filter_params = {}
|
||||||
|
keys = [
|
||||||
|
'transaction_id', 'user_id', 'reference',
|
||||||
|
'created_before', 'created_after',
|
||||||
|
]
|
||||||
|
for key in keys:
|
||||||
|
value = request.query.get(key, request.get(key, None))
|
||||||
|
filter_params[key] = value
|
||||||
|
|
||||||
|
logger.info(f"Querying transactions with params: {filter_params=}")
|
||||||
|
txns = await Transaction.query(request.app, **filter_params)
|
||||||
|
payload = [await txn.prepare() for txn in txns]
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/transactions/{transaction_id}')
|
||||||
|
@routes.view('/transactions/{transaction_id}/', name='transaction_id')
|
||||||
|
class TransactionView(web.View):
|
||||||
|
async def resolve_transaction(self):
|
||||||
|
request = self.request
|
||||||
|
txn_id = request.match_info['transaction_id']
|
||||||
|
txn = await Transaction.fetch_from_id(request.app, int(txn_id))
|
||||||
|
|
||||||
|
if txn is None:
|
||||||
|
raise web.HTTPNotFound(text="No transaction exists with the given ID.")
|
||||||
|
|
||||||
|
return txn
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
txn = await self.resolve_transaction()
|
||||||
|
logger.info(f"Received GET for transaction {txn=}")
|
||||||
|
payload = await txn.prepare()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def patch(self):
|
||||||
|
raise web.HTTPBadRequest(text="Transactions are immutable and cannot be edited.")
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
raise web.HTTPBadRequest(text="Transactions cannot be individually deleted.")
|
||||||
|
|
||||||
|
|
||||||
|
@routes.route('*', "/transactions/{transaction_id}/user")
|
||||||
|
@routes.route('*', "/transactions/{transaction_id}/user{tail:/.*}")
|
||||||
|
async def transaction_user_route(request: web.Request):
|
||||||
|
txn_id = int(request.match_info['transaction_id'])
|
||||||
|
txn = await Transaction.fetch_from_id(request.app, txn_id)
|
||||||
|
if txn is None:
|
||||||
|
raise web.HTTPNotFound(text="No transaction exists with the given ID.")
|
||||||
|
|
||||||
|
tail = request.match_info.get('tail', '')
|
||||||
|
new_path = "/users/{user_id}".format(user_id=txn.row.user_id) + tail
|
||||||
|
|
||||||
|
logger.info(f"Redirecting {request=} to {new_path}")
|
||||||
|
new_request = request.clone(rel_url=new_path)
|
||||||
|
new_request['user_id'] = txn.row.user_id
|
||||||
|
match_info = await request.app.router.resolve(new_request)
|
||||||
|
new_request._match_info = match_info
|
||||||
|
match_info.current_app = request.app
|
||||||
|
|
||||||
|
if match_info.handler:
|
||||||
|
return await match_info.handler(new_request)
|
||||||
|
else:
|
||||||
|
logger.info(f"Could not find handler matching {new_request}")
|
||||||
|
raise web.HTTPNotFound()
|
||||||
461
src/routes/users.py
Normal file
461
src/routes/users.py
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
"""
|
||||||
|
- `/users` with `POST`, `GET`, `PATCH`, `DELETE`
|
||||||
|
- `/users/{user_id}` with `GET`, `PATCH`, `DELETE`
|
||||||
|
- `/users/{user_id}/events` which is passed to `/events`
|
||||||
|
- `/users/{user_id}/specimen` which is passed to `/specimens/{specimen_id}`
|
||||||
|
- `/users/{user_id}/specimens` which is passed to `/specimens`
|
||||||
|
- `/users/{user_id}/wallet` with `GET`
|
||||||
|
- `/users/{user_id}/transactions` which is passed to `/transactions`
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List
|
||||||
|
from aiohttp import web
|
||||||
|
import discord
|
||||||
|
from data import Condition, condition
|
||||||
|
from data.conditions import NULL
|
||||||
|
from data.queries import JOINTYPE
|
||||||
|
from datamodels import DataModel
|
||||||
|
|
||||||
|
from modules.profiles.data import ProfileData
|
||||||
|
from utils.lib import MessageArgs, tabulate
|
||||||
|
|
||||||
|
from .lib import ModelField, datamodelsv, dbvar, event_log, profiledatav
|
||||||
|
from .specimens import Specimen, SpecimenPayload
|
||||||
|
|
||||||
|
routes = web.RouteTableDef()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class UserPayload(TypedDict):
|
||||||
|
user_id: int
|
||||||
|
twitch_id: Optional[str]
|
||||||
|
name: Optional[str]
|
||||||
|
preferences: Optional[str]
|
||||||
|
created_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserDetailsPayload(UserPayload):
|
||||||
|
specimen: Optional[SpecimenPayload]
|
||||||
|
inventory: List # TODO
|
||||||
|
wallet: int
|
||||||
|
|
||||||
|
|
||||||
|
class UserCreateParamsReq(TypedDict, total=True):
|
||||||
|
twitch_id: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserCreateParams(UserCreateParamsReq, total=False):
|
||||||
|
preferences: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserEditParams(TypedDict, total=False):
|
||||||
|
name: Optional[str]
|
||||||
|
preferences: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
fields = [
|
||||||
|
ModelField('user_id', int, False, False, False),
|
||||||
|
ModelField('twitch_id', str, True, True, False),
|
||||||
|
ModelField('name', str, True, True, True),
|
||||||
|
ModelField('preferences', str, False, True, True),
|
||||||
|
ModelField('created_at', str, False, False, False),
|
||||||
|
]
|
||||||
|
req_fields = {field.name for field in fields if field.required}
|
||||||
|
edit_fields = {field.name for field in fields if field.can_edit}
|
||||||
|
create_fields = {field.name for field in fields if field.can_create}
|
||||||
|
|
||||||
|
|
||||||
|
class User:
|
||||||
|
def __init__(self, app: web.Application, row: DataModel.Dreamer):
|
||||||
|
self.app = app
|
||||||
|
self.data = app[datamodelsv]
|
||||||
|
self.profile_data = app[profiledatav]
|
||||||
|
|
||||||
|
self.row = row
|
||||||
|
self._pref_row: Optional[DataModel.UserPreferences] = None
|
||||||
|
|
||||||
|
async def get_prefs(self) -> DataModel.UserPreferences:
|
||||||
|
if self._pref_row is None:
|
||||||
|
self._pref_row = await self.data.UserPreferences.fetch_or_create(self.row.user_id)
|
||||||
|
return self._pref_row
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def validate_create_params(cls, params):
|
||||||
|
if extra := next((key for key in params if key not in create_fields), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"Invalid key '{extra}' passed to user creation.")
|
||||||
|
if missing := next((key for key in req_fields if key not in params), None):
|
||||||
|
raise web.HTTPBadRequest(text=f"User params missing required key '{missing}'")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_id(cls, app: web.Application, user_id: int):
|
||||||
|
data = app[datamodelsv]
|
||||||
|
row = await data.Dreamer.fetch(int(user_id))
|
||||||
|
return cls(app, row) if row is not None else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def query(
|
||||||
|
cls,
|
||||||
|
app: web.Application,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
twitch_id: Optional[str] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
created_before: Optional[str] = None,
|
||||||
|
created_after: Optional[str] = None,
|
||||||
|
) -> List[Self]:
|
||||||
|
data = app[datamodelsv]
|
||||||
|
Dreamer = data.Dreamer
|
||||||
|
|
||||||
|
conds = []
|
||||||
|
if user_id is not None:
|
||||||
|
conds.append(Dreamer.user_id == int(user_id))
|
||||||
|
if twitch_id is not None:
|
||||||
|
conds.append(Dreamer.twitch_id == twitch_id)
|
||||||
|
if name is not None:
|
||||||
|
conds.append(Dreamer.name == name)
|
||||||
|
if created_before is not None:
|
||||||
|
cbefore = datetime.fromisoformat(created_before)
|
||||||
|
conds.append(Dreamer.created_at <= cbefore)
|
||||||
|
if created_after is not None:
|
||||||
|
cafter = datetime.fromisoformat(created_after)
|
||||||
|
conds.append(Dreamer.created_at >= cafter)
|
||||||
|
|
||||||
|
rows = await Dreamer.fetch_where(*conds).order_by(Dreamer.created_at)
|
||||||
|
return [cls(app, row) for row in rows]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, app: web.Application, **kwargs: Unpack[UserCreateParams]):
|
||||||
|
"""
|
||||||
|
Create a new User from the provided data.
|
||||||
|
|
||||||
|
This creates the associated UserProfile, TwitchProfile, and UserPreferences if needed.
|
||||||
|
If a profile already exists, this does *not* error.
|
||||||
|
Instead, this updates the existing User with the new data.
|
||||||
|
"""
|
||||||
|
data = app[datamodelsv]
|
||||||
|
|
||||||
|
twitch_id = kwargs['twitch_id']
|
||||||
|
name = kwargs['name']
|
||||||
|
prefs = kwargs.get('preferences')
|
||||||
|
|
||||||
|
# Quick sanity check on the twitch id
|
||||||
|
if not twitch_id or not twitch_id.isdigit():
|
||||||
|
raise web.HTTPBadRequest(text="Invalid 'twitch_id' passed to user creation!")
|
||||||
|
|
||||||
|
# First check if the profile already exists by querying the Dreamer database
|
||||||
|
edited = 0 # 0 means not edited, 1 means created, 2 means modified
|
||||||
|
rows = await data.Dreamer.fetch_where(twitch_id=twitch_id)
|
||||||
|
if rows:
|
||||||
|
logger.debug(f"Updating Dreamer for {twitch_id=} with {kwargs}")
|
||||||
|
dreamer = rows[0]
|
||||||
|
# A twitch profile with this twitch_id already exists
|
||||||
|
# But it is possible UserPreferences don't exist
|
||||||
|
if dreamer.preferences is None and dreamer.name is None:
|
||||||
|
await data.UserPreferences.fetch_or_create(dreamer.user_id, twitch_name=name, preferences=prefs)
|
||||||
|
dreamer = await dreamer.refresh()
|
||||||
|
edited = 2
|
||||||
|
|
||||||
|
# Now compare the existing data against the provided data and update if needed
|
||||||
|
if name != dreamer.name:
|
||||||
|
q = data.UserPreferences.table.update_where(profileid=dreamer.user_id)
|
||||||
|
q.set(twitch_name=name)
|
||||||
|
if prefs is not None:
|
||||||
|
q.set(preferences=prefs)
|
||||||
|
await q
|
||||||
|
dreamer = await dreamer.refresh()
|
||||||
|
edited = 2
|
||||||
|
else:
|
||||||
|
# Create from scratch
|
||||||
|
logger.info(f"Creating Dreamer for {twitch_id=} with {kwargs}")
|
||||||
|
# TODO: Should be in a transaction.. actually let's add transactions to the middleware..
|
||||||
|
profile_data = app[profiledatav]
|
||||||
|
user_profile = await profile_data.UserProfileRow.create(nickname=name)
|
||||||
|
await profile_data.TwitchProfileRow.create(
|
||||||
|
profileid=user_profile.profileid,
|
||||||
|
userid=twitch_id,
|
||||||
|
)
|
||||||
|
await data.UserPreferences.create(
|
||||||
|
profileid=user_profile.profileid,
|
||||||
|
twitch_name=name,
|
||||||
|
preferences=prefs
|
||||||
|
)
|
||||||
|
dreamer = await data.Dreamer.fetch(user_profile.profileid)
|
||||||
|
assert dreamer is not None
|
||||||
|
edited = 1
|
||||||
|
|
||||||
|
self = cls(app, dreamer)
|
||||||
|
if edited == 1:
|
||||||
|
args = await self.event_log_args(title=f"User #{dreamer.user_id} created!")
|
||||||
|
await event_log(**args.send_args)
|
||||||
|
elif edited == 2:
|
||||||
|
args = await self.event_log_args(title=f"User #{dreamer.user_id} updated!")
|
||||||
|
await event_log(**args.send_args)
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def edit(self, **kwargs: Unpack[UserEditParams]):
|
||||||
|
data = self.data
|
||||||
|
# We can edit the name, and preferences
|
||||||
|
prefs = await self.get_prefs()
|
||||||
|
update_args = {}
|
||||||
|
if 'name' in kwargs:
|
||||||
|
update_args['twitch_name'] = kwargs['name']
|
||||||
|
if 'preferences' in kwargs:
|
||||||
|
update_args['preferences'] = kwargs['preferences']
|
||||||
|
|
||||||
|
if update_args:
|
||||||
|
logger.info(f"Updating dreamer {self.row=} with {kwargs}")
|
||||||
|
await prefs.update(**update_args)
|
||||||
|
|
||||||
|
args = await self.event_log_args(title=f"User #{self.row.user_id} updated!")
|
||||||
|
await event_log(**args.send_args)
|
||||||
|
|
||||||
|
async def delete(self) -> UserDetailsPayload:
|
||||||
|
payload = await self.prepare(details=True)
|
||||||
|
# This will cascade to all other data the user has
|
||||||
|
await self.profile_data.UserProfileRow.table.delete_where(profileid=self.row.user_id)
|
||||||
|
# Make sure we take the user out of cache
|
||||||
|
await self.row.refresh()
|
||||||
|
return payload
|
||||||
|
|
||||||
|
async def get_wallet(self):
|
||||||
|
query = self.data.Transaction.table.select_where(user_id=self.row.user_id)
|
||||||
|
query.select(wallet="SUM(amount)")
|
||||||
|
query.with_no_adapter()
|
||||||
|
results = await query
|
||||||
|
|
||||||
|
return results[0]['wallet']
|
||||||
|
|
||||||
|
async def get_specimen(self) -> Optional[Specimen]:
|
||||||
|
data = self.data
|
||||||
|
active_specrows = await data.Specimen.fetch_where(
|
||||||
|
owner_id=self.row.user_id,
|
||||||
|
forgotten_at=NULL
|
||||||
|
)
|
||||||
|
if active_specrows:
|
||||||
|
row = active_specrows[0]
|
||||||
|
spec = Specimen(self.app, row)
|
||||||
|
else:
|
||||||
|
spec = None
|
||||||
|
return spec
|
||||||
|
|
||||||
|
async def get_inventory(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def event_log_args(self, **kwargs) -> MessageArgs:
|
||||||
|
desc = '\n'.join(await self.tabulate())
|
||||||
|
embed = discord.Embed(description=desc, timestamp=self.row.created_at, **kwargs)
|
||||||
|
embed.set_footer(text='Created At')
|
||||||
|
|
||||||
|
# TODO: We could add wallet, specimen, and inventory info here too
|
||||||
|
return MessageArgs(embed=embed)
|
||||||
|
|
||||||
|
async def tabulate(self):
|
||||||
|
"""
|
||||||
|
Present the User as a discord-readable table.
|
||||||
|
"""
|
||||||
|
table = {
|
||||||
|
'user_id': f"`{self.row.user_id}`",
|
||||||
|
'twitch_id': f"`{self.row.twitch_id}`" if self.row.twitch_id else 'No Twitch linked',
|
||||||
|
'name': f"`{self.row.name}`",
|
||||||
|
'preferences': f"`{self.row.preferences}`",
|
||||||
|
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
|
||||||
|
}
|
||||||
|
return tabulate(*table.items())
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def prepare(self, details: Literal[True]=True) -> UserDetailsPayload:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def prepare(self, details: Literal[False]=False) -> UserPayload:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def prepare(self, details=False) -> UserPayload | UserDetailsPayload:
|
||||||
|
# Since we are working with view rows, make sure we refresh
|
||||||
|
row = self.row
|
||||||
|
await row.refresh()
|
||||||
|
|
||||||
|
base_user: UserPayload = {
|
||||||
|
'user_id': row.user_id,
|
||||||
|
'twitch_id': str(row.twitch_id) if row.twitch_id else None,
|
||||||
|
'name': row.name,
|
||||||
|
'preferences': row.preferences,
|
||||||
|
'created_at': row.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if details:
|
||||||
|
# Now add details
|
||||||
|
specimen = await self.get_specimen()
|
||||||
|
sp_payload = await specimen.prepare() if specimen is not None else None
|
||||||
|
inventory = [await item.prepare() for item in await self.get_inventory()]
|
||||||
|
user: UserPayload = base_user | {
|
||||||
|
'specimen': sp_payload,
|
||||||
|
'inventory': inventory,
|
||||||
|
'wallet': await self.get_wallet(),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
user = base_user
|
||||||
|
logger.debug(f"User prepared: {user}")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/users')
|
||||||
|
@routes.view('/users/', name='users')
|
||||||
|
class UsersView(web.View):
|
||||||
|
async def post(self):
|
||||||
|
request = self.request
|
||||||
|
|
||||||
|
params = await request.json()
|
||||||
|
for key in create_fields:
|
||||||
|
if key in request:
|
||||||
|
params.setdefault(key, request[key])
|
||||||
|
|
||||||
|
await User.validate_create_params(params)
|
||||||
|
logger.info(f"Creating a new user with args: {params=}")
|
||||||
|
user = await User.create(self.request.app, **params)
|
||||||
|
logger.debug(f"Created user: {user!r}")
|
||||||
|
payload = await user.prepare(details=True)
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
request = self.request
|
||||||
|
filter_params = {}
|
||||||
|
keys = [
|
||||||
|
'user_id', 'twitch_id', 'name', 'created_before', 'created_after',
|
||||||
|
]
|
||||||
|
for key in keys:
|
||||||
|
value = request.query.get(key, request.get(key, None))
|
||||||
|
filter_params[key] = value
|
||||||
|
|
||||||
|
logger.info(f"Querying users with params: {filter_params=}")
|
||||||
|
users = await User.query(request.app, **filter_params)
|
||||||
|
payload = [await user.prepare(details=True) for user in users]
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@routes.view('/users/{user_id}')
|
||||||
|
@routes.view('/users/{user_id}/', name='user')
|
||||||
|
class UserView(web.View):
|
||||||
|
async def resolve_user(self):
|
||||||
|
request = self.request
|
||||||
|
user_id = request.match_info['user_id']
|
||||||
|
user = await User.fetch_from_id(request.app, int(user_id))
|
||||||
|
if user is None:
|
||||||
|
raise web.HTTPNotFound(text="No user exists with the given ID.")
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get(self):
|
||||||
|
user = await self.resolve_user()
|
||||||
|
logger.info(f"Received GET for user {user=}")
|
||||||
|
payload = await user.prepare(details=True)
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def patch(self):
|
||||||
|
user = await self.resolve_user()
|
||||||
|
params = await self.request.json()
|
||||||
|
|
||||||
|
edit_data = {}
|
||||||
|
for key, value in params.items():
|
||||||
|
if key not in edit_fields:
|
||||||
|
raise web.HTTPBadRequest(text=f"You cannot update field '{key}' of User!")
|
||||||
|
edit_data[key] = value
|
||||||
|
|
||||||
|
for key in edit_fields:
|
||||||
|
if key in self.request:
|
||||||
|
edit_data.setdefault(key, self.request[key])
|
||||||
|
|
||||||
|
logger.info(f"Received PATCH for user {user} with params: {params}")
|
||||||
|
await user.edit(**edit_data)
|
||||||
|
payload = await user.prepare(details=True)
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
async def delete(self):
|
||||||
|
user = await self.resolve_user()
|
||||||
|
logger.info(f"Received DELETE for user {user}")
|
||||||
|
payload = await user.delete()
|
||||||
|
return web.json_response(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@routes.route('*', "/users/{user_id}{tail:/events}")
|
||||||
|
@routes.route('*', "/users/{user_id}{tail:/events/.*}")
|
||||||
|
@routes.route('*', "/users/{user_id}{tail:/transactions}")
|
||||||
|
@routes.route('*', "/users/{user_id}{tail:/transactions/.*}")
|
||||||
|
@routes.route('*', "/users/{user_id}{tail:/specimens}")
|
||||||
|
@routes.route('*', "/users/{user_id}{tail:/specimens/.*}")
|
||||||
|
async def user_prefix_routes(request: web.Request):
|
||||||
|
user_id = int(request.match_info['user_id'])
|
||||||
|
user = await User.fetch_from_id(request.app, user_id)
|
||||||
|
if user is None:
|
||||||
|
raise web.HTTPNotFound(text="No user exists with the given ID.")
|
||||||
|
|
||||||
|
new_path = request.match_info['tail']
|
||||||
|
logger.info(f"Redirecting {request=} to {new_path=} and setting {user_id=}")
|
||||||
|
|
||||||
|
new_request = request.clone(rel_url=new_path)
|
||||||
|
new_request['user_id'] = user_id
|
||||||
|
match_info = await request.app.router.resolve(new_request)
|
||||||
|
new_request._match_info = match_info
|
||||||
|
match_info.current_app = request.app
|
||||||
|
if match_info.handler:
|
||||||
|
return await match_info.handler(new_request)
|
||||||
|
else:
|
||||||
|
logger.info(f"Could not find handler matching {new_request}")
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
|
||||||
|
|
||||||
|
@routes.route('*', "/users/{user_id}/specimen")
|
||||||
|
@routes.route('*', "/users/{user_id}/specimen{tail:/.*}")
|
||||||
|
async def user_specimen_route(request: web.Request):
|
||||||
|
user_id = int(request.match_info['user_id'])
|
||||||
|
user = await User.fetch_from_id(request.app, user_id)
|
||||||
|
if user is None:
|
||||||
|
raise web.HTTPNotFound(text="No user exists with the given ID.")
|
||||||
|
tail = request.match_info.get('tail', '')
|
||||||
|
|
||||||
|
specimen = await user.get_specimen()
|
||||||
|
if request.method == 'POST' and not tail.strip('/'):
|
||||||
|
if specimen is None:
|
||||||
|
# Redirect to POST /specimens
|
||||||
|
# TODO: Would be nicer to use named handler here
|
||||||
|
new_path = '/specimens'
|
||||||
|
logger.info(f"Redirecting {request=} to POST /specimens")
|
||||||
|
new_request = request.clone(rel_url=new_path)
|
||||||
|
new_request['user_id'] = user_id
|
||||||
|
new_request['owner_id'] = user_id
|
||||||
|
match_info = await request.app.router.resolve(new_request)
|
||||||
|
new_request._match_info = match_info
|
||||||
|
match_info.current_app = request.app
|
||||||
|
return await match_info.handler(new_request)
|
||||||
|
else:
|
||||||
|
raise web.HTTPBadRequest(text="This user already has an active specimen!")
|
||||||
|
elif specimen is None:
|
||||||
|
raise web.HTTPNotFound(text="This user has no active specimen.")
|
||||||
|
else:
|
||||||
|
specimen_id = specimen.row.specimen_id
|
||||||
|
# Redirect to POST /specimens/{specimen_id}/...
|
||||||
|
new_path = f"/specimens/{specimen_id}".format(specimen_id=specimen_id) + tail
|
||||||
|
logger.info(f"Redirecting {request=} to {new_path}")
|
||||||
|
new_request = request.clone(rel_url=new_path)
|
||||||
|
new_request['user_id'] = user_id
|
||||||
|
new_request['owner_id'] = user_id
|
||||||
|
new_request['specimen_id'] = specimen_id
|
||||||
|
match_info = await request.app.router.resolve(new_request)
|
||||||
|
new_request._match_info = match_info
|
||||||
|
match_info.current_app = request.app
|
||||||
|
if match_info.handler:
|
||||||
|
return await match_info.handler(new_request)
|
||||||
|
else:
|
||||||
|
logger.info(f"Could not find handler matching {new_request}")
|
||||||
|
raise web.HTTPNotFound()
|
||||||
|
|
||||||
|
|
||||||
|
@routes.route('GET', "/users/{user_id}/wallet")
|
||||||
|
@routes.route('GET', "/users/{user_id}/wallet/")
|
||||||
|
async def user_wallet_route(request: web.Request):
|
||||||
|
user_id = int(request.match_info['user_id'])
|
||||||
|
user = await User.fetch_from_id(request.app, user_id)
|
||||||
|
if user is None:
|
||||||
|
raise web.HTTPNotFound(text="No user exists with the given ID.")
|
||||||
|
wallet = await user.get_wallet()
|
||||||
|
return web.json_response(wallet)
|
||||||
27
src/utils/auth.py
Normal file
27
src/utils/auth.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from typing import Awaitable, Callable
|
||||||
|
from aiohttp import hdrs, web
|
||||||
|
|
||||||
|
CHALLENGE = web.Response(
|
||||||
|
body="<b> 401 UNAUTHORIZED </b>",
|
||||||
|
status=401,
|
||||||
|
reason='UNAUTHORIZED',
|
||||||
|
headers={
|
||||||
|
hdrs.WWW_AUTHENTICATE: 'X-API-KEY',
|
||||||
|
hdrs.CONTENT_TYPE: 'text/html; charset=utf-8',
|
||||||
|
hdrs.CONNECTION: 'close',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def key_auth_factory(required_token: str):
|
||||||
|
"""
|
||||||
|
Creates an aiohttp middleware that ensures
|
||||||
|
the `required_token` is provided in the `X-API-KEY` header.
|
||||||
|
"""
|
||||||
|
@web.middleware
|
||||||
|
async def key_auth_middleware(request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]]) -> web.StreamResponse:
|
||||||
|
auth_header = request.headers.get('X-API-KEY')
|
||||||
|
if not auth_header or auth_header.strip() != required_token:
|
||||||
|
return CHALLENGE
|
||||||
|
else:
|
||||||
|
return await handler(request)
|
||||||
|
return key_auth_middleware
|
||||||
Reference in New Issue
Block a user