Compare commits

..

13 Commits

21 changed files with 1433 additions and 109 deletions

View File

@@ -231,14 +231,14 @@ CREATE TABLE plain_events (
event_id integer PRIMARY KEY, event_id integer PRIMARY KEY,
event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'), event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'),
message TEXT NOT NULL, message TEXT NOT NULL,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
); );
CREATE TABLE raid_events ( CREATE TABLE raid_events (
event_id integer PRIMARY KEY, event_id integer PRIMARY KEY,
event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'), event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'),
visitor_count INTEGER NOT NULL, visitor_count INTEGER NOT NULL,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
); );
CREATE TABLE cheer_events ( CREATE TABLE cheer_events (
@@ -247,7 +247,7 @@ CREATE TABLE cheer_events (
amount INTEGER NOT NULL, amount INTEGER NOT NULL,
cheer_type TEXT, cheer_type TEXT,
message TEXT, message TEXT,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
); );
CREATE TABLE subscriber_events ( CREATE TABLE subscriber_events (
@@ -256,7 +256,7 @@ CREATE TABLE subscriber_events (
subscribed_length INTEGER NOT NULL, subscribed_length INTEGER NOT NULL,
tier INTEGER NOT NULL, tier INTEGER NOT NULL,
message TEXT, message TEXT,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
); );

View File

@@ -11,54 +11,14 @@ from datamodels import DataModel
from constants import DATA_VERSION from constants import DATA_VERSION
from modules.profiles.data import ProfileData from modules.profiles.data import ProfileData
from routes import dbvar, datamodelsv, profiledatav, register_routes
from routes.stamps import routes as stamp_routes
from routes.documents import routes as doc_routes
from routes.users import routes as user_routes
from routes.specimens import routes as spec_routes
from routes.transactions import routes as txn_routes
from routes.events import routes as event_routes
from routes.lib import dbvar, datamodelsv, profiledatav
sys.path.insert(0, os.path.join(os.getcwd())) sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src")) sys.path.insert(0, os.path.join(os.getcwd(), "src"))
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO: Move the route table to the __init__ of routes
# Maybe we can join route tables together?
# Or we just expose an add_routes or register method
"""
- `/stamps` with `POST`, `PUT`, `GET`
- `/stamps/{stamp_id}` with `GET`, `PATCH`, `DELETE`
- `/documents` with `POST, GET`
- `/documents/{document_id}` with `GET`, `PATCH`, `DELETE`
- `/documents/{document_id}/stamps` which is passed to `/stamps` with `document_id` set.
- `/events` with `POST`, `GET`
- `/events/{event_id}` with `GET`, `PATCH`, `DELETE`
- `/events/{event_id}/document` which is passed to `/documents/{document_id}`
- `/events/{event_id}/user` which is passed to `/users/{user_id}`
- `/users` with `POST`, `GET`, `PATCH`, `DELETE`
- `/users/{user_id}` with `GET`, `PATCH`, `DELETE`
- `/users/{user_id}/events` which is passed to `/events`
- `/users/{user_id}/specimen` which is passed to `/specimens/{specimen_id}`
- `/users/{user_id}/specimens` which is passed to `/specimens`
- `/users/{user_id}/wallet` with `GET`
- `/users/{user_id}/transactions` which is passed to `/transactions`
- `/specimens` with `GET` and `POST`
- `/specimens/{specimen_id}` with `PATCH` and `DELETE`
- `/specimens/{specimen_id}/owner` which is passed to `/users/{user_id}`
- `/transactions` with `POST`, `GET`
- `/transactions/{transaction_id}` with `GET`, `PATCH`, `DELETE`
- `/transactions/{transaction_id}/user` which is passed to `/users/{user_id}`
"""
async def attach_db(app: web.Application): async def attach_db(app: web.Application):
db = Database(conf.data['args']) db = Database(conf.data['args'])
async with db.open(): async with db.open():
@@ -84,27 +44,18 @@ async def attach_db(app: web.Application):
async def test(request: web.Request) -> web.Response: async def test(request: web.Request) -> web.Response:
return web.Response(text="Hello World") return web.Response(text="Welcome to the Dreamspace API. Please donate an important childhood memory to continue.")
async def app_factory(): def app_factory():
auth = key_auth_factory(conf.API['TOKEN']) auth = key_auth_factory(conf.API['TOKEN'])
app = web.Application(middlewares=[auth]) app = web.Application(middlewares=[auth])
app.cleanup_ctx.append(attach_db) app.cleanup_ctx.append(attach_db)
app.router.add_get('/', test) app.router.add_get('/', test)
app.router.add_routes(stamp_routes) register_routes(app.router)
app.router.add_routes(doc_routes)
app.router.add_routes(user_routes)
app.router.add_routes(spec_routes)
app.router.add_routes(event_routes)
app.router.add_routes(txn_routes)
return app return app
async def run_app(): if __name__ == '__main__':
app = await app_factory() app = app_factory()
web.run_app(app, port=int(conf.API['PORT'])) web.run_app(app, port=int(conf.API['PORT']))
if __name__ == '__main__':
asyncio.run(run_app())

6
src/brand.py Normal file
View File

@@ -0,0 +1,6 @@
import discord
# Theme
MAIN_COLOUR = discord.Colour.from_str('#11EA11')
ACCENT_COLOUR = discord.Colour.from_str('#EA11EA')

View File

@@ -11,6 +11,7 @@ from meta.app import shardname, appname
from meta.logger import log_wrap from meta.logger import log_wrap
from utils.lib import utc_now from utils.lib import utc_now
from datamodels import DataModel
from .data import CoreData from .data import CoreData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,7 +30,9 @@ class CoreCog(LionCog):
def __init__(self, bot: LionBot): def __init__(self, bot: LionBot):
self.bot = bot self.bot = bot
self.data = CoreData() self.data = CoreData()
self.datamodel = DataModel()
bot.db.load_registry(self.data) bot.db.load_registry(self.data)
bot.db.load_registry(self.datamodel)
self.app_config: Optional[CoreData.AppConfig] = None self.app_config: Optional[CoreData.AppConfig] = None
self.bot_config: Optional[CoreData.BotConfig] = None self.bot_config: Optional[CoreData.BotConfig] = None
@@ -43,6 +46,9 @@ class CoreCog(LionCog):
self.app_config = await self.data.AppConfig.fetch_or_create(appname) self.app_config = await self.data.AppConfig.fetch_or_create(appname)
self.bot_config = await self.data.BotConfig.fetch_or_create(appname) self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
await self.data.init()
await self.datamodel.init()
# Load the app command cache # Load the app command cache
await self.reload_appcmd_cache() await self.reload_appcmd_cache()

View File

@@ -1,4 +1,8 @@
from io import BytesIO
import base64
from enum import Enum from enum import Enum
from typing import NamedTuple
from data import Registry, RowModel, Table, RegisterEnum from data import Registry, RowModel, Table, RegisterEnum
from data.columns import Integer, String, Timestamp, Column from data.columns import Integer, String, Timestamp, Column
@@ -9,6 +13,52 @@ class EventType(Enum):
CHEER = 'cheer', CHEER = 'cheer',
PLAIN = 'plain', PLAIN = 'plain',
def info(self):
if self is EventType.SUBSCRIBER:
info = EventTypeInfo(
EventType.SUBSCRIBER,
DataModel.subscriber_events,
("tier", "subscribed_length", "message"),
("tier", "subscribed_length", "message"),
('subscriber_tier', 'subscriber_length', 'subscriber_message'),
)
elif self is EventType.RAID:
info = EventTypeInfo(
EventType.RAID,
DataModel.raid_events,
('visitor_count',),
('viewer_count',),
('raid_visitor_count',),
)
elif self is EventType.CHEER:
info = EventTypeInfo(
EventType.CHEER,
DataModel.cheer_events,
('amount', 'cheer_type', 'message'),
('amount', 'cheer_type', 'message'),
('cheer_amount', 'cheer_type', 'cheer_message'),
)
elif self is EventType.PLAIN:
info = EventTypeInfo(
EventType.PLAIN,
DataModel.plain_events,
('message',),
('message',),
('plain_message',),
)
else:
raise ValueError("Unexpected event type.")
return info
class EventTypeInfo(NamedTuple):
typ: EventType
table: Table
columns: tuple[str, ...]
params: tuple[str, ...]
detailcolumns: tuple[str, ...]
class DataModel(Registry): class DataModel(Registry):
_EventType = RegisterEnum(EventType, 'EventType') _EventType = RegisterEnum(EventType, 'EventType')
@@ -118,6 +168,14 @@ class DataModel(Registry):
metadata = String() metadata = String()
created_at = Timestamp() created_at = Timestamp()
def to_bytes(self):
"""
Helper method to decode the saved document data to a byte string.
This may fail if the saved string is not base64 encoded.
"""
byts = BytesIO(base64.b64decode(self.document_data))
return byts
class DocumentStamp(RowModel): class DocumentStamp(RowModel):
""" """
Schema Schema
@@ -182,14 +240,14 @@ class DataModel(Registry):
event_id integer PRIMARY KEY, event_id integer PRIMARY KEY,
event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'), event_type EventType NOT NULL DEFAULT 'plain' CHECK (event_type = 'plain'),
message TEXT NOT NULL, message TEXT NOT NULL,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
); );
CREATE TABLE raid_events ( CREATE TABLE raid_events (
event_id integer PRIMARY KEY, event_id integer PRIMARY KEY,
event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'), event_type EventType NOT NULL DEFAULT 'raid' CHECK (event_type = 'raid'),
visitor_count INTEGER NOT NULL, visitor_count INTEGER NOT NULL,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
); );
CREATE TABLE cheer_events ( CREATE TABLE cheer_events (
@@ -198,7 +256,7 @@ class DataModel(Registry):
amount INTEGER NOT NULL, amount INTEGER NOT NULL,
cheer_type TEXT, cheer_type TEXT,
message TEXT, message TEXT,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
); );
CREATE TABLE subscriber_events ( CREATE TABLE subscriber_events (
@@ -207,7 +265,7 @@ class DataModel(Registry):
subscribed_length INTEGER NOT NULL, subscribed_length INTEGER NOT NULL,
tier INTEGER NOT NULL, tier INTEGER NOT NULL,
message TEXT, message TEXT,
FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) FOREIGN KEY (event_id, event_type) REFERENCES events (event_id, event_type) ON DELETE CASCADE
); );
CREATE VIEW event_details AS CREATE VIEW event_details AS

76
src/meta/CrocBot.py Normal file
View File

@@ -0,0 +1,76 @@
from collections import defaultdict
from typing import TYPE_CHECKING
import logging
import twitchio
from twitchio.ext import commands
from twitchio.ext import pubsub
from twitchio.ext.commands.core import itertools
from data import Database
from .config import Conf
logger = logging.getLogger(__name__)
class CrocBot(commands.Bot):
def __init__(self, *args,
config: Conf,
data: Database,
**kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.data = data
self.pubsub = pubsub.PubSubPool(self)
self._member_cache = defaultdict(dict)
async def event_ready(self):
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
async def event_join(self, channel: twitchio.Channel, user: twitchio.User):
self._member_cache[channel.name][user.name] = user
async def event_message(self, message: twitchio.Message):
if message.channel and message.author:
self._member_cache[message.channel.name][message.author.name] = message.author
await self.handle_commands(message)
async def seek_user(self, userstr: str, matching=True, fuzzy=True):
if userstr.startswith('@'):
matching = False
userstr = userstr.strip('@ ')
result = None
if matching and len(userstr) >= 3:
lowered = userstr.lower()
full_matches = []
for user in itertools.chain(*(cmems.values() for cmems in self._member_cache.values())):
matchstr = user.name.lower()
print(matchstr)
if matchstr.startswith(lowered):
result = user
break
if lowered in matchstr:
full_matches.append(user)
if result is None and full_matches:
result = full_matches[0]
print(result)
if result is None:
lookup = userstr
elif result.id is None:
lookup = result.name
else:
lookup = None
if lookup:
found = await self.fetch_users(names=[lookup])
if found:
result = found[0]
# No matches found
return result

View File

@@ -3,6 +3,7 @@ this_package = 'modules'
active = [ active = [
'.profiles', '.profiles',
'.sysadmin', '.sysadmin',
'.dreamspace',
] ]

View File

@@ -0,0 +1,8 @@
import logging
logger = logging.getLogger(__name__)
async def setup(bot):
from .cog import DreamCog
await bot.add_cog(DreamCog(bot))

View File

@@ -0,0 +1,70 @@
import asyncio
import discord
from discord import app_commands as appcmds
from discord.ext import commands as cmds
from meta import LionCog, LionBot, LionContext
from meta.logger import log_wrap
from utils.lib import utc_now
from . import logger
from .ui.docviewer import DocumentViewer
class DreamCog(LionCog):
"""
Discord-facting interface for Dreamspace Adventures
"""
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.core.datamodel
async def cog_load(self):
pass
@log_wrap(action="Dreamer migration")
async def migrate_dreamer(self, source_profile, target_profile):
"""
Called when two dreamer profiles need to merge.
For example, when a user links a second twitch profile.
:TODO-MARKER:
Most of the migration logic is simple, e.g. just update the profileid
on the old events to the new profile.
The same applies to transactions and probably to inventory items.
However, there are some subtle choices to make, such as what to do
if both the old and the new profile have an active specimen?
A profile can only have one active specimen at a time.
There is also the question of how to merge user preferences, when those exist.
"""
...
# User command: view their dreamer card, wallet inventory etc
# (Admin): View events/documents matching certain criteria
# (User): View own event cards with info?
# Let's make a demo viewer which lists their event cards and let's them open one via select?
# /documents -> Show a paged list of documents, select option displays the document in a viewer
@cmds.hybrid_command(
name='documents',
description="View your printer log!"
)
async def documents_cmd(self, ctx: LionContext):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
events = await self.data.Events.fetch_where(user_id=profile.profileid)
docids = [event.document_id for event in events if event.document_id is not None]
if not docids:
await ctx.error_reply("You don't have any documents yet!")
return
view = DocumentViewer(self.bot, ctx.interaction.user.id, filter=(self.data.Document.document_id == docids))
await view.run(ctx.interaction)
# (User): View Specimen information

View File

@@ -0,0 +1,36 @@
from datamodels import DataModel
class Event:
_typs = {}
def __init__(self, event_row: DataModel.Events, **kwargs):
self.row = event_row
def __getattribute__(self, name: str):
...
async def get_document(self):
...
async def get_user(self):
...
class Document:
def as_bytes(self):
...
async def get_stamps(self):
...
async def refresh(self):
...
class User:
...
class Stamp:
...

View File

@@ -0,0 +1,171 @@
import binascii
from itertools import chain
from typing import Optional
from dataclasses import dataclass
import asyncio
import datetime as dt
import discord
from discord.ui.select import select, Select, UserSelect
from discord.ui.button import button, Button
from discord.ui.text_input import TextInput
from discord.enums import ButtonStyle, TextStyle
from discord.components import SelectOption
from datamodels import DataModel
from meta import LionBot, conf
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
from data import ORDER, Condition
from utils.ui import MessageUI, input
from utils.lib import MessageArgs, tabulate, utc_now
from .. import logger
class DocumentViewer(MessageUI):
"""
Simple pager which displays a filtered list of Documents.
"""
block_len = 5
def __init__(self, bot: LionBot, callerid: int, filter: Condition, **kwargs):
super().__init__(callerid=callerid, **kwargs)
self.bot = bot
self.data: DataModel = bot.core.datamodel
self.filter = filter
# Paging state
self._pagen = 0
self.blocks = [[]]
@property
def page_count(self):
return len(self.blocks)
@property
def pagen(self):
self._pagen %= self.page_count
return self._pagen
@pagen.setter
def pagen(self, value):
self._pagen = value % self.page_count
@property
def current_page(self):
return self.blocks[self.pagen]
# ----- UI Components -----
# Page backwards
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
async def prev_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True)
self.pagen -= 1
await self.refresh(thinking=press)
# Jump to page
@button(label="JUMP_PLACEHOLDER", style=ButtonStyle.blurple)
async def jump_button(self, press: discord.Interaction, pressed: Button):
"""
Jump to page button.
"""
try:
interaction, value = await input(
press,
title="Jump to page",
question="Page number to jump to"
)
value = value.strip()
except asyncio.TimeoutError:
return
if not value.lstrip('- ').isdigit():
error = discord.Embed(title="Invalid page number, please try again!",
colour=discord.Colour.brand_red())
await interaction.response.send_message(embed=error, ephemeral=True)
else:
await interaction.response.defer(thinking=True)
pagen = int(value.lstrip('- '))
if value.startswith('-'):
pagen = -1 * pagen
elif pagen > 0:
pagen = pagen - 1
self.pagen = pagen
await self.refresh(thinking=interaction)
async def jump_button_refresh(self):
component = self.jump_button
component.label = f"{self.pagen + 1}/{self.page_count}"
component.disabled = (self.page_count <= 1)
# Page forwards
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
async def next_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True)
self.pagen += 1
await self.refresh(thinking=press)
# Quit
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
async def quit_button(self, press: discord.Interaction, pressed: Button):
"""
Quit the UI.
"""
await press.response.defer()
# if self.child_viewer:
# await self.child_viewer.quit()
await self.quit()
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
files = []
embeds = []
for doc in self.current_page:
try:
imagedata = doc.to_bytes()
imagedata.seek(0)
except binascii.Error:
continue
fn = f"doc-{doc.document_id}.png"
file = discord.File(imagedata, fn)
embed = discord.Embed()
embed.set_image(url=f"attachment://{fn}")
files.append(file)
embeds.append(embed)
if not embeds:
embed = discord.Embed(description="You don't have any documents yet!")
embeds.append(embed)
print(f"FILES: {files}")
return MessageArgs(files=files, embeds=embeds)
async def refresh_layout(self):
to_refresh = (
self.jump_button_refresh(),
)
await asyncio.gather(*to_refresh)
if self.page_count > 1:
page_line = (
self.prev_button,
self.jump_button,
self.quit_button,
self.next_button,
)
else:
page_line = (self.quit_button,)
self.set_layout(page_line)
async def reload(self):
docs = await self.data.Document.fetch_where(self.filter).order_by('created_at', ORDER.DESC)
blocks = [
docs[i:i+self.block_len]
for i in range(0, len(docs), self.block_len)
]
self.blocks = blocks or [[]]

View File

@@ -0,0 +1,406 @@
from itertools import chain
from typing import Optional
from dataclasses import dataclass
import asyncio
import datetime as dt
import discord
from discord.ui.select import select, Select, UserSelect
from discord.ui.button import button, Button
from discord.ui.text_input import TextInput
from discord.enums import ButtonStyle, TextStyle
from discord.components import SelectOption
from datamodels import DataModel
from meta import LionBot, conf
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
from data import ORDER, Condition
from utils.ui import MessageUI, input
from utils.lib import MessageArgs, tabulate, utc_now
from .. import logger
class EventsUI(MessageUI):
block_len = 10
def __init__(self, bot: LionBot, callerid: int, filter: Condition, **kwargs):
super().__init__(callerid=callerid, **kwargs)
self.bot = bot
self.data: DataModel = bot.core.datamodel
self.filter = Condition
# Paging state
self._pagen = 0
self.blocks = [[]]
@property
def page_count(self):
return len(self.blocks)
@property
def pagen(self):
self._pagen %= self.page_count
return self._pagen
@pagen.setter
def pagen(self, value):
self._pagen = value % self.page_count
@property
def current_page(self):
return self.blocks[self.pagen]
# ----- UI Components -----
# Page backwards
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
async def prev_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True)
self.pagen -= 1
await self.refresh(thinking=press)
# Jump to page
# Page forwards
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
async def next_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True)
self.pagen += 1
await self.refresh(thinking=press)
# Quit
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
async def quit_button(self, press: discord.Interaction, pressed: Button):
"""
Quit the UI.
"""
await press.response.defer()
# if self.child_viewer:
# await self.child_viewer.quit()
await self.quit()
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
...
async def refresh_layout(self):
...
async def reload(self):
...
class TicketListUI(MessageUI):
# Select Ticket
@select(
cls=Select,
placeholder="TICKETS_MENU_PLACEHOLDER",
min_values=1, max_values=1
)
async def tickets_menu(self, selection: discord.Interaction, selected: Select):
await selection.response.defer(thinking=True, ephemeral=True)
if selected.values:
ticketid = int(selected.values[0])
ticket = await Ticket.fetch_ticket(self.bot, ticketid)
ticketui = TicketUI(self.bot, ticket, self._callerid)
if self.child_ticket:
await self.child_ticket.quit()
self.child_ticket = ticketui
await ticketui.run(selection)
async def tickets_menu_refresh(self):
menu = self.tickets_menu
t = self.bot.translator.t
menu.placeholder = t(_p(
'ui:tickets|menu:tickets|placeholder',
"Select Ticket"
))
options = []
for ticket in self.current_page:
option = SelectOption(
label=f"Ticket #{ticket.data.guild_ticketid}",
value=str(ticket.data.ticketid)
)
options.append(option)
menu.options = options
# Backwards
@button(emoji=conf.emojis.backward, style=ButtonStyle.grey)
async def prev_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True, ephemeral=True)
self.pagen -= 1
await self.refresh(thinking=press)
# Jump to page
@button(label="JUMP_PLACEHOLDER", style=ButtonStyle.blurple)
async def jump_button(self, press: discord.Interaction, pressed: Button):
"""
Jump-to-page button.
Loads a page-switch dialogue.
"""
t = self.bot.translator.t
try:
interaction, value = await input(
press,
title=t(_p(
'ui:tickets|button:jump|input:title',
"Jump to page"
)),
question=t(_p(
'ui:tickets|button:jump|input:question',
"Page number to jump to"
))
)
value = value.strip()
except asyncio.TimeoutError:
return
if not value.lstrip('- ').isdigit():
error_embed = discord.Embed(
title=t(_p(
'ui:tickets|button:jump|error:invalid_page',
"Invalid page number, please try again!"
)),
colour=discord.Colour.brand_red()
)
await interaction.response.send_message(embed=error_embed, ephemeral=True)
else:
await interaction.response.defer(thinking=True)
pagen = int(value.lstrip('- '))
if value.startswith('-'):
pagen = -1 * pagen
elif pagen > 0:
pagen = pagen - 1
self.pagen = pagen
await self.refresh(thinking=interaction)
async def jump_button_refresh(self):
component = self.jump_button
component.label = f"{self.pagen + 1}/{self.page_count}"
component.disabled = (self.page_count <= 1)
# Forward
@button(emoji=conf.emojis.forward, style=ButtonStyle.grey)
async def next_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True)
self.pagen += 1
await self.refresh(thinking=press)
# Quit
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
async def quit_button(self, press: discord.Interaction, pressed: Button):
"""
Quit the UI.
"""
await press.response.defer()
if self.child_ticket:
await self.child_ticket.quit()
await self.quit()
# ----- UI Flow -----
def _format_ticket(self, ticket) -> str:
"""
Format a ticket into a single embed line.
"""
components = (
"[#{ticketid}]({link})",
"{created}",
"`{type}[{state}]`",
"<@{targetid}>",
"{content}",
)
formatstr = ' | '.join(components)
data = ticket.data
if not data.content:
content = 'No Content'
elif len(data.content) > 100:
content = data.content[:97] + '...'
else:
content = data.content
ticketstr = formatstr.format(
ticketid=data.guild_ticketid,
link=ticket.jump_url or 'https://lionbot.org',
created=discord.utils.format_dt(data.created_at, 'd'),
type=data.ticket_type.name,
state=data.ticket_state.name,
targetid=data.targetid,
content=content,
)
if data.ticket_state is TicketState.PARDONED:
ticketstr = f"~~{ticketstr}~~"
return ticketstr
async def make_message(self) -> MessageArgs:
t = self.bot.translator.t
embed = discord.Embed(
title=t(_p(
'ui:tickets|embed|title',
"Moderation Tickets in {guild}"
)).format(guild=self.guild.name),
timestamp=utc_now()
)
tickets = self.current_page
if tickets:
desc = '\n'.join(self._format_ticket(ticket) for ticket in tickets)
else:
desc = t(_p(
'ui:tickets|embed|desc:no_tickets',
"No tickets matching the given criteria!"
))
embed.description = desc
filterstr = self.filters.formatted()
if filterstr:
embed.add_field(
name=t(_p(
'ui:tickets|embed|field:filters|name',
"Filters"
)),
value=filterstr,
inline=False
)
return MessageArgs(embed=embed)
async def refresh_layout(self):
to_refresh = (
self.edit_filter_button_refresh(),
self.select_ticket_button_refresh(),
self.pardon_button_refresh(),
self.tickets_menu_refresh(),
self.filter_type_menu_refresh(),
self.filter_state_menu_refresh(),
self.filter_target_menu_refresh(),
self.jump_button_refresh(),
)
await asyncio.gather(*to_refresh)
action_line = (
self.edit_filter_button,
self.select_ticket_button,
self.pardon_button,
)
if self.page_count > 1:
page_line = (
self.prev_button,
self.jump_button,
self.quit_button,
self.next_button,
)
else:
page_line = ()
action_line = (*action_line, self.quit_button)
if self.show_filters:
menus = (
(self.filter_type_menu,),
(self.filter_state_menu,),
(self.filter_target_menu,),
)
elif self.show_tickets and self.current_page:
menus = ((self.tickets_menu,),)
else:
menus = ()
self.set_layout(
action_line,
*menus,
page_line,
)
async def reload(self):
tickets = await Ticket.fetch_tickets(
self.bot,
*self.filters.conditions(),
guildid=self.guild.id,
)
blocks = [
tickets[i:i+self.block_len]
for i in range(0, len(tickets), self.block_len)
]
self.blocks = blocks or [[]]
class TicketUI(MessageUI):
def __init__(self, bot: LionBot, ticket: Ticket, callerid: int, **kwargs):
super().__init__(callerid=callerid, **kwargs)
self.bot = bot
self.ticket = ticket
# ----- API -----
# ----- UI Components -----
# Pardon Ticket
@button(
label="PARDON_BUTTON_PLACEHOLDER",
style=ButtonStyle.red
)
async def pardon_button(self, press: discord.Interaction, pressed: Button):
t = self.bot.translator.t
modal_title = t(_p(
'ui:ticket|button:pardon|modal:reason|title',
"Pardon Moderation Ticket"
))
input_field = TextInput(
label=t(_p(
'ui:ticket|button:pardon|modal:reason|field|label',
"Why are you pardoning this ticket?"
)),
style=TextStyle.long,
min_length=0,
max_length=1024,
)
try:
interaction, reason = await input(
press, modal_title, field=input_field, timeout=300,
)
except asyncio.TimeoutError:
raise ResponseTimedOut
await interaction.response.defer(thinking=True, ephemeral=True)
await self.ticket.pardon(modid=press.user.id, reason=reason)
await self.refresh(thinking=interaction)
async def pardon_button_refresh(self):
button = self.pardon_button
t = self.bot.translator.t
button.label = t(_p(
'ui:ticket|button:pardon|label',
"Pardon"
))
button.disabled = (self.ticket.data.ticket_state is TicketState.PARDONED)
# Quit
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
async def quit_button(self, press: discord.Interaction, pressed: Button):
"""
Quit the UI.
"""
await press.response.defer()
await self.quit()
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
return await self.ticket.make_message()
async def refresh_layout(self):
await self.pardon_button_refresh()
self.set_layout(
(self.pardon_button, self.quit_button,)
)
async def reload(self):
await self.ticket.data.refresh()

View File

View File

@@ -21,6 +21,8 @@ from .data import ProfileData
from .profile import UserProfile from .profile import UserProfile
from .community import Community from .community import Community
from .ui import TwitchLinkStatic, TwitchLinkFlow
class ProfileCog(LionCog): class ProfileCog(LionCog):
def __init__(self, bot: LionBot): def __init__(self, bot: LionBot):
@@ -34,6 +36,8 @@ class ProfileCog(LionCog):
async def cog_load(self): async def cog_load(self):
await self.data.init() await self.data.init()
self.bot.add_view(TwitchLinkStatic(timeout=None))
async def cog_check(self, ctx): async def cog_check(self, ctx):
return True return True
@@ -197,6 +201,16 @@ class ProfileCog(LionCog):
community = await Community.create_from_twitch(self.bot, user) community = await Community.create_from_twitch(self.bot, user)
return community return community
# ----- Admin Commands -----
@cmds.hybrid_command(
name='linkoffer',
description="Send a message with a permanent button for profile linking"
)
@appcmds.default_permissions(manage_guild=True)
async def linkoffer_cmd(self, ctx: LionContext):
view = TwitchLinkStatic(timeout=None)
await ctx.channel.send(embed=view.embed, view=view)
# ----- Profile Commands ----- # ----- Profile Commands -----
@cmds.hybrid_group( @cmds.hybrid_group(
name='profiles', name='profiles',
@@ -217,6 +231,13 @@ class ProfileCog(LionCog):
description="Link a twitch account to your current profile." description="Link a twitch account to your current profile."
) )
async def profiles_link_twitch_cmd(self, ctx: LionContext): async def profiles_link_twitch_cmd(self, ctx: LionContext):
if not ctx.interaction:
return
flowui = TwitchLinkFlow(self.bot, ctx.author, callerid=ctx.author.id)
await flowui.run(ctx.interaction)
await flowui.wait()
async def old_profiles_link_twitch_cmd(self, ctx: LionContext):
if not ctx.interaction: if not ctx.interaction:
return return

View File

@@ -0,0 +1 @@
from .twitchlink import TwitchLinkStatic, TwitchLinkFlow

View File

@@ -0,0 +1,337 @@
"""
UI Views for Twitch linkage.
- Persistent view with interaction-button to enter link flow.
- We don't store the view, but we listen to interaction button id.
- Command to enter link flow.
For link flow, send ephemeral embed with instructions and what to expect, with link button below.
After auth is granted through OAuth flow (or if not granted, e.g. on timeout or failure)
edit the embed to reflect auth situation.
If migration occurred, add the migration text as a field to the embed.
"""
import asyncio
from datetime import timedelta
from enum import IntEnum
from typing import Optional
import aiohttp
import discord
from discord.ui.button import button, Button
from discord.enums import ButtonStyle
from twitchAPI.helper import first
from meta import LionBot
from meta.errors import SafeCancellation
from meta.logger import log_wrap
from utils.ui import MessageUI
from utils.lib import MessageArgs, utc_now
from utils.ui.leo import LeoUI
from modules.profiles.profile import UserProfile
import brand
from .. import logger
class TwitchLinkStatic(LeoUI):
"""
Static UI whose only job is to display a persistent button
to ask people to connect their twitch account.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._embed: Optional[discord.Embed] = None
print("INITIALISATION")
async def interaction_check(self, interaction: discord.Interaction):
return True
@property
def embed(self) -> discord.Embed:
"""
This is the persistent message people will see with the button that starts the Oauth flow.
Not sure what this should actually say, or whether it should be customisable via command.
:TODO-MARKER:
"""
embed = discord.Embed(
title="Link your Twitch account!",
description=(
"To participate in the Dreamspace Adventure Game :TM:, "
"please start by pressing the button below to begin the login flow with Twitch!"
),
colour=brand.ACCENT_COLOUR,
)
return embed
@embed.setter
def embed(self, value):
self._embed = value
@button(label="Connect", custom_id="BTN-LINK-TWITCH", style=ButtonStyle.green, emoji='🔗')
@log_wrap(action="link-twitch-btn")
async def button_linker(self, interaction: discord.Interaction, btn: Button):
# Here we just reply to the interaction with the AuthFlow UI
# TODO
print("RESPONDING")
flowui = TwitchLinkFlow(interaction.client, interaction.user, callerid=interaction.user.id)
await flowui.run(interaction)
await flowui.wait()
class FlowState(IntEnum):
SETUP = -1
WAITING = 0
CANCELLED = 1
TIMEOUT = 2
ERRORED = 3
WORKING = 9
DONE = 10
class TwitchLinkFlow(MessageUI):
def __init__(self, bot: LionBot, caller: discord.User | discord.Member, *args, **kwargs):
kwargs.setdefault('callerid', caller.id)
super().__init__(*args, **kwargs)
self.bot = bot
self._auth_task = None
self._stage: FlowState = FlowState.SETUP
self.flow = None
self.authrow = None
self.user = caller
self._info = None
self._migration_details = None
# ----- UI API -----
async def run(self, interaction: discord.Interaction, **kwargs):
await interaction.response.defer(ephemeral=True, thinking=True)
await self._start_flow()
await self.draw(interaction, **kwargs)
if self._stage is FlowState.ERRORED:
# This can happen if starting the flow failed
await self.close()
@log_wrap(action="start-twitch-flow-ui")
async def _start_flow(self):
logger.info(f"Starting twitch authentication flow for {self.user}")
try:
self.flow = await self.bot.get_cog('TwitchAuthCog').start_auth()
except aiohttp.ClientError:
self._stage = FlowState.ERRORED
self._info = (
"Could not establish a connection to the authentication server! "
"Please try again later~"
)
logger.exception("Unexpected exception while starting authentication flow!", exc_info=True)
else:
self._stage = FlowState.WAITING
self._auth_task = asyncio.create_task(self._auth_flow())
@log_wrap(action="run-twitch-flow-ui")
async def _auth_flow(self):
"""
Run the flow and wait for a timeout, cancellation, or callback.
Update the message accordingly.
"""
assert self.flow is not None
try:
# TODO: Cancel this in cleanup
authrow = await asyncio.wait_for(self.flow.run(), timeout=60)
except asyncio.TimeoutError:
self._stage = FlowState.TIMEOUT
# Link Timed Out!
self._info = (
"We didn't receive a response so we closed the uplink "
"to keep your account safe! If you still want to connect, please try again!"
)
await self.refresh()
await self.close()
except asyncio.CancelledError:
# Presumably the user exited or the bot is shutting down.
# Not safe to edit the message, but try and cleanup
await self.close()
except SafeCancellation as e:
logger.info("User or server cancelled authentication flow: ", exc_info=True)
# Uplink Cancelled!
self._info = (
f"We couldn't complete the uplink!\nReason:*{e.msg}*"
)
self._stage = FlowState.CANCELLED
await self.refresh()
await self.close()
except Exception:
logger.exception("Something unexpected went wrong while running the flow!")
else:
self._stage = FlowState.WORKING
self._info = (
"Authentication complete! Connecting your Dreamspace account ...."
)
self.authrow = authrow
await self.refresh()
await self._link_twitch(str(authrow.userid))
await self.refresh()
await self.close()
async def cleanup(self):
await super().cleanup()
if self._auth_task and not self._auth_task.cancelled():
self._auth_task.cancel()
async def _link_twitch(self, twitch_id: str):
"""
Link the caller's profile to the given twitch_id.
Performs migration if needed.
"""
try:
twitch_user = await first(self.bot.twitch.get_users(user_ids=[twitch_id]))
except Exception:
logger.exception(
f"Looking up user {self.authrow} from Twitch authentication flow raised an error."
)
self._stage = FlowState.ERRORED
self._info = "Failed to look up your user details from Twitch! Please try again later."
return
if twitch_user is None:
logger.error(
f"User {self.authrow} obtained from Twitch authentication does not exist."
)
self._stage = FlowState.ERRORED
self._info = "Authentication failed! Please try again later."
return
profiles = self.bot.get_cog('ProfileCog')
userid = self.user.id
caller_profile = await UserProfile.fetch_from_discordid(self.bot, userid)
twitch_profile = await UserProfile.fetch_from_twitchid(self.bot, twitch_id)
succ_info = (
f"Successfully established uplink to your Twitch account **{twitch_user.display_name}** "
"and transferred dreamspace data! Happy adventuring, and watch out for the grue~"
)
# ::TODO-MARKER::
if twitch_profile is None:
if caller_profile is None:
# Neither profile exists
profile = await UserProfile.create_from_discord(self.bot, self.user)
await profile.attach_twitch(twitch_id)
self._stage = FlowState.DONE
self._info = succ_info
else:
await caller_profile.attach_twitch(twitch_id)
self._stage = FlowState.DONE
self._info = succ_info
else:
if caller_profile is None:
await twitch_profile.attach_discord(self.user.id)
self._stage = FlowState.DONE
self._info = succ_info
elif twitch_profile.profileid == caller_profile.profileid:
self._stage = FlowState.CANCELLED
self._info = (
f"The Twitch account **{twitch_user.display_name}** is already linked to your profile!"
)
else:
# In this case we have conflicting profiles we need to migrate
try:
results = await profiles.migrate_profile(twitch_profile, caller_profile)
except Exception:
self._stage = FlowState.ERRORED
self._info = (
"An issue was encountered while merging your account profiles! "
"The migration was rolled back, and not data has been lost.\n"
"The developer has been notified, please try again later!"
)
logger.exception(f"Failed to migrate profiles {twitch_profile=} to {caller_profile=}")
else:
self._stage = FlowState.DONE
self._info = succ_info
self._migration_details = '\n'.join(results)
logger.info(
f"Migrated {twitch_profile=} to {caller_profile}. Info: {self._migration_details}"
)
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
if self._stage is FlowState.SETUP:
raise ValueError("Making message before flow initialisation!")
assert self.flow is not None
if self._stage is FlowState.WAITING:
# Message should be the initial request page
dur = discord.utils.format_dt(utc_now() + timedelta(seconds=60), style='R')
title = "Press the button to login!"
desc = (
"We have generated a custom secure link for you to connect your Twitch profile! "
"Press the button below and accept the connection in your browser, "
"and we will begin the transfer!\n"
f"(Note: The link expires {dur})"
)
colour = brand.ACCENT_COLOUR
elif self._stage is FlowState.CANCELLED:
# Show cancellation message
# Show 'you can close this'
title = "Uplink Cancelled!"
desc = self._info
colour = discord.Colour.brand_red()
elif self._stage is FlowState.TIMEOUT:
title = "Link Timed Out"
desc = self._info
colour = discord.Colour.brand_red()
elif self._stage is FlowState.ERRORED:
title = "Something went wrong!"
desc = self._info
colour = discord.Colour.brand_red()
elif self._stage is FlowState.WORKING:
# We've received the auth, we are now doing migration
title = "Establishing Connection"
desc = self._info
colour = brand.ACCENT_COLOUR
elif self._stage is FlowState.DONE:
title = "Success!"
desc = self._info
colour = discord.Colour.brand_green()
else:
raise ValueError(f"Invalid stage value {self._stage}")
embed = discord.Embed(title=title, description=desc, colour=colour, timestamp=utc_now())
if self._migration_details:
embed.add_field(
name="Profile migration details",
value=self._migration_details
)
return MessageArgs(embed=embed)
async def refresh_layout(self):
# If we haven't received the auth callback yet, make the flow link button
if self.flow is None:
raise ValueError("Refreshing before flow initialisation!")
if self._stage <= FlowState.WAITING:
flow_link = self.flow.auth.return_auth_url()
button = Button(
style=ButtonStyle.link,
url=flow_link,
label="Login With Twitch"
)
self.set_layout((button,))
else:
self.set_layout(())
async def reload(self):
pass

16
src/routes/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
from .stamps import routes as stamp_routes
from .documents import routes as doc_routes
from .users import routes as user_routes
from .specimens import routes as spec_routes
from .transactions import routes as txn_routes
from .events import routes as event_routes
from .lib import dbvar, datamodelsv, profiledatav
def register_routes(router):
router.add_routes(stamp_routes)
router.add_routes(doc_routes)
router.add_routes(user_routes)
router.add_routes(spec_routes)
router.add_routes(event_routes)
router.add_routes(txn_routes)

View File

@@ -4,14 +4,17 @@
- `/documents/{document_id}/stamps` which is passed to `/stamps` with `document_id` set. - `/documents/{document_id}/stamps` which is passed to `/stamps` with `document_id` set.
""" """
import logging import logging
import binascii
from datetime import datetime from datetime import datetime
from typing import Any, NamedTuple, Optional, Self, TypedDict, Unpack, reveal_type, List from typing import Any, NamedTuple, Optional, Self, TypedDict, Unpack, reveal_type, List
from aiohttp import web from aiohttp import web
import discord
from data import Condition, condition from data import Condition, condition
from data.queries import JOINTYPE from data.queries import JOINTYPE
from datamodels import DataModel from datamodels import DataModel
from utils.lib import MessageArgs, tabulate
from .lib import ModelField, datamodelsv from .lib import ModelField, datamodelsv, event_log
from .stamps import Stamp, StampCreateParams, StampEditParams, StampPayload from .stamps import Stamp, StampCreateParams, StampEditParams, StampPayload
routes = web.RouteTableDef() routes = web.RouteTableDef()
@@ -153,12 +156,76 @@ class Document:
stamp = await Stamp.create(app, **stampdata) stamp = await Stamp.create(app, **stampdata)
stamps.append(stamp) stamps.append(stamp)
return cls(app, row) self = cls(app, row)
# await self.log_create()
return self
async def get_stamps(self) -> List[Stamp]: async def get_stamps(self) -> List[Stamp]:
stamprows = await self.data.DocumentStamp.table.fetch_rows_where(document_id=self.row.document_id).order_by('stamp_id') stamprows = await self.data.DocumentStamp.table.fetch_rows_where(document_id=self.row.document_id).order_by('stamp_id')
return [Stamp(self.app, row) for row in stamprows] return [Stamp(self.app, row) for row in stamprows]
async def log_create(self, with_image=True):
args = await self.event_log_args(with_image=with_image)
args.kwargs['embed'].title = f"Document #{self.row.document_id} Created!"
try:
await event_log(**args.send_args)
except discord.HTTPException:
if with_image:
# Try again without the image in case that was the issue
await self.log_create(with_image=False)
async def log_edit(self, with_image=True):
args = await self.event_log_args(with_image=with_image)
args.kwargs['embed'].title = f"Document #{self.row.document_id} Updated!"
try:
await event_log(**args.send_args)
except discord.HTTPException:
if with_image:
# Try again without the image in case that was the issue
await self.log_create(with_image=False)
async def event_log_args(self, with_image=True) -> MessageArgs:
desc = '\n'.join(await self.tabulate())
embed = discord.Embed(description=desc, timestamp=self.row.created_at)
embed.set_footer(text='Created At')
args: dict = {'embed': embed}
if with_image:
try:
imagedata = self.row.to_bytes()
imagedata.seek(0)
embed.set_image(url='attachment://document.png')
args['files'] = [discord.File(imagedata, "document.png")]
except binascii.Error:
# Could not decode base64
embed.add_field(name='Image', value="Could not decode document data!")
else:
embed.add_field(name='Image', value="Failed to send image!")
return MessageArgs(**args)
async def tabulate(self):
"""
Present the Document as a discord-readable table.
"""
stamps = await self.get_stamps()
typnames = []
if stamps:
typs = {stamp.row.stamp_type for stamp in stamps}
for typ in typs:
# Stamp types should be cached so this isn't expensive
typrow = await self.data.StampType.fetch(typ)
typnames.append(typrow.stamp_type_name)
table = {
'document_id': f"`{self.row.document_id}`",
'seal': str(self.row.seal),
'metadata': f"`{self.row.metadata}`" if self.row.metadata else "No metadata",
'stamps': ', '.join(f"`{name}`" for name in typnames) if typnames else "No stamps",
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
}
return tabulate(*table.items())
async def prepare(self) -> DocPayload: async def prepare(self) -> DocPayload:
stamps = await self.get_stamps() stamps = await self.get_stamps()
@@ -192,6 +259,7 @@ class Document:
for stampdata in new_stamps: for stampdata in new_stamps:
stampdata.setdefault('document_id', row.document_id) stampdata.setdefault('document_id', row.document_id)
await Stamp.create(self.app, **stampdata) await Stamp.create(self.app, **stampdata)
await self.log_edit()
async def delete(self) -> DocPayload: async def delete(self) -> DocPayload:
payload = await self.prepare() payload = await self.prepare()

View File

@@ -1,15 +1,18 @@
import binascii
import logging import logging
from datetime import datetime from datetime import datetime
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List
from aiohttp import web from aiohttp import web
import discord
from data import Condition, condition from data import Condition, condition
from data.conditions import NULL from data.conditions import NULL
from data.queries import JOINTYPE from data.queries import JOINTYPE
from datamodels import DataModel, EventType from datamodels import DataModel, EventType
from modules.profiles.data import ProfileData from modules.profiles.data import ProfileData
from utils.lib import MessageArgs, tabulate
from .lib import ModelField, datamodelsv, dbvar, profiledatav from .lib import ModelField, datamodelsv, dbvar, event_log, profiledatav
from .specimens import Specimen, SpecimenPayload from .specimens import Specimen, SpecimenPayload
routes = web.RouteTableDef() routes = web.RouteTableDef()
@@ -72,7 +75,7 @@ class Event:
conds.append(EventD.occurred_at >= after) conds.append(EventD.occurred_at >= after)
if event_type is not None: if event_type is not None:
ekey = (event_type.lower().strip(),) ekey = (event_type.lower().strip(),)
if ekey not in EventType: if ekey not in [e.value for e in EventType]:
raise web.HTTPBadRequest(text=f"Unknown event type '{event_type}'") raise web.HTTPBadRequest(text=f"Unknown event type '{event_type}'")
conds.append(EventD.event_type == EventType(ekey)) conds.append(EventD.event_type == EventType(ekey))
@@ -85,7 +88,7 @@ class Event:
raise web.HTTPBadRequest(text="Event creation missing required field 'event_type'.") raise web.HTTPBadRequest(text="Event creation missing required field 'event_type'.")
ekey = (params['event_type'].lower().strip(),) ekey = (params['event_type'].lower().strip(),)
if ekey not in EventType: if ekey not in [e.value for e in EventType]:
raise web.HTTPBadRequest(text=f"Unknown event type '{params['event_type']}'") raise web.HTTPBadRequest(text=f"Unknown event type '{params['event_type']}'")
event_type = EventType(ekey) event_type = EventType(ekey)
@@ -128,7 +131,7 @@ class Event:
# EventD = data.EventDetails # EventD = data.EventDetails
ekey = (kwargs['event_type'].lower().strip(),) ekey = (kwargs['event_type'].lower().strip(),)
if ekey not in EventType: if ekey not in [e.value for e in EventType]:
raise web.HTTPBadRequest(text=f"Unknown event type '{kwargs['event_type']}'") raise web.HTTPBadRequest(text=f"Unknown event type '{kwargs['event_type']}'")
event_type = EventType(ekey) event_type = EventType(ekey)
@@ -146,7 +149,7 @@ class Event:
typparams['message'] = kwargs.get('message') typparams['message'] = kwargs.get('message')
case EventType.RAID: case EventType.RAID:
typtab = data.raid_events typtab = data.raid_events
typparams['viewer_count'] = kwargs.get('viewer_count') typparams['visitor_count'] = kwargs.get('viewer_count')
case EventType.SUBSCRIBER: case EventType.SUBSCRIBER:
typtab = data.subscriber_events typtab = data.subscriber_events
typparams['tier'] = kwargs['tier'] typparams['tier'] = kwargs['tier']
@@ -196,8 +199,76 @@ class Event:
details = await data.EventDetails.fetch(eventrow.event_id) details = await data.EventDetails.fetch(eventrow.event_id)
assert details is not None assert details is not None
return cls(app, details) self = cls(app, details)
await self.log_create()
return self
async def log_create(self, with_image=True):
args = await self.event_log_args(with_image=with_image)
args.kwargs['embed'].title = f"Event #{self.row.event_id} Created!"
try:
await event_log(**args.send_args)
except discord.HTTPException:
if with_image:
# Try again without the image in case that was the issue
await self.log_create(with_image=False)
async def log_edit(self, with_image=True):
args = await self.event_log_args(with_image=with_image)
args.kwargs['embed'].title = f"Event #{self.row.event_id} Updated!"
try:
await event_log(**args.send_args)
except discord.HTTPException:
if with_image:
# Try again without the image in case that was the issue
await self.log_create(with_image=False)
async def event_log_args(self, with_image=True) -> MessageArgs:
desc = '\n'.join(await self.tabulate())
embed = discord.Embed(description=desc, timestamp=self.row.created_at)
embed.set_footer(text='Created At')
args: dict = {'embed': embed}
doc = await self.get_document()
if doc is not None:
embed.add_field(
name="Document",
value='\n'.join(await doc.tabulate()),
inline=False
)
if with_image:
try:
imagedata = doc.row.to_bytes()
imagedata.seek(0)
embed.set_image(url='attachment://document.png')
args['files'] = [discord.File(imagedata, "document.png")]
except binascii.Error:
# Could not decode base64
embed.add_field(name='Image', value="Could not decode document data!")
else:
embed.add_field(name='Image', value="Failed to send image!")
return MessageArgs(**args)
async def tabulate(self):
"""
Present the Event as a discord-readable table.
"""
user = await self.get_user()
assert user is not None
table = {
'event_id': f"`{self.row.event_id}`",
'event_type': f"`{self.row.event_type}`",
'user': f"`{self.row.user_id}` (`{self.row.user_name}`)",
'document': f"`{self.row.document_id}`",
'occurred_at': discord.utils.format_dt(self.row.occurred_at, 'F'),
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
}
info = self.row.event_type.info()
for col, param in zip(info.detailcolumns, info.params):
value = getattr(self.row, col)
table[param] = f"`{value}`"
return tabulate(*table.items())
async def edit(self, **kwargs): async def edit(self, **kwargs):
data = self.data data = self.data
@@ -220,9 +291,8 @@ class Event:
typparams[key] = kwargs[key] typparams[key] = kwargs[key]
case EventType.RAID: case EventType.RAID:
typtab = data.raid_events typtab = data.raid_events
for key in ('viewer_count',): if 'viewer_count' in kwargs:
if key in kwargs: typparams['visitor_count'] = 'viewer_count'
typparams[key] = kwargs[key]
case EventType.SUBSCRIBER: case EventType.SUBSCRIBER:
typtab = data.subscriber_events typtab = data.subscriber_events
for key in ('tier', 'subscribed_length', 'message'): for key in ('tier', 'subscribed_length', 'message'):
@@ -231,6 +301,7 @@ class Event:
if typparams: if typparams:
await typtab.update_where(event_id=self.row.event_id).set(**typparams) await typtab.update_where(event_id=self.row.event_id).set(**typparams)
await self.log_edit()
await self.row.refresh() await self.row.refresh()
async def delete(self): async def delete(self):

View File

@@ -1,9 +1,11 @@
from typing import NamedTuple, Any, Optional, Self, Unpack, List, TypedDict from typing import NamedTuple, Any, Optional, Self, Unpack, List, TypedDict
from aiohttp import web from aiohttp import web, ClientSession
from discord import Webhook
from data.database import Database from data.database import Database
from datamodels import DataModel from datamodels import DataModel
from modules.profiles.data import ProfileData from modules.profiles.data import ProfileData
from meta import conf
dbvar = web.AppKey("database", Database) dbvar = web.AppKey("database", Database)
datamodelsv = web.AppKey("datamodels", DataModel) datamodelsv = web.AppKey("datamodels", DataModel)
@@ -18,32 +20,10 @@ class ModelField(NamedTuple):
can_edit: bool can_edit: bool
class ModelClassABC[RowT, Payload: TypedDict, CreateParams: TypedDict, EditParams: TypedDict]: async def event_log(*args, **kwargs):
def __init__(self, app: web.Application, row: RowT): # Post the given message to the configured event log, if set
self.app = app event_log_url = conf.api.get('EVENTLOG')
self.data = app[datamodelsv] if event_log_url:
self.row = row async with ClientSession() as session:
webhook = Webhook.from_url(event_log_url, session=session)
@classmethod await webhook.send(**kwargs)
async def fetch_from_id(cls, app: web.Application, document_id: int) -> Optional[Self]:
...
@classmethod
async def query(
cls,
**kwargs
) -> List[Self]:
...
@classmethod
async def create(cls, app: web.Application, **kwargs: Unpack[CreateParams]) -> Self:
...
async def prepare(self) -> Payload:
...
async def edit(self, **kwargs: Unpack[EditParams]):
...
async def delete(self) -> Payload:
...

View File

@@ -11,14 +11,16 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List from typing import Any, Literal, NamedTuple, Optional, Self, TypedDict, Unpack, overload, reveal_type, List
from aiohttp import web from aiohttp import web
import discord
from data import Condition, condition from data import Condition, condition
from data.conditions import NULL from data.conditions import NULL
from data.queries import JOINTYPE from data.queries import JOINTYPE
from datamodels import DataModel from datamodels import DataModel
from modules.profiles.data import ProfileData from modules.profiles.data import ProfileData
from utils.lib import MessageArgs, tabulate
from .lib import ModelField, datamodelsv, dbvar, profiledatav from .lib import ModelField, datamodelsv, dbvar, event_log, profiledatav
from .specimens import Specimen, SpecimenPayload from .specimens import Specimen, SpecimenPayload
routes = web.RouteTableDef() routes = web.RouteTableDef()
@@ -141,6 +143,7 @@ class User:
raise web.HTTPBadRequest(text="Invalid 'twitch_id' passed to user creation!") raise web.HTTPBadRequest(text="Invalid 'twitch_id' passed to user creation!")
# First check if the profile already exists by querying the Dreamer database # First check if the profile already exists by querying the Dreamer database
edited = 0 # 0 means not edited, 1 means created, 2 means modified
rows = await data.Dreamer.fetch_where(twitch_id=twitch_id) rows = await data.Dreamer.fetch_where(twitch_id=twitch_id)
if rows: if rows:
logger.debug(f"Updating Dreamer for {twitch_id=} with {kwargs}") logger.debug(f"Updating Dreamer for {twitch_id=} with {kwargs}")
@@ -150,6 +153,7 @@ class User:
if dreamer.preferences is None and dreamer.name is None: if dreamer.preferences is None and dreamer.name is None:
await data.UserPreferences.fetch_or_create(dreamer.user_id, twitch_name=name, preferences=prefs) await data.UserPreferences.fetch_or_create(dreamer.user_id, twitch_name=name, preferences=prefs)
dreamer = await dreamer.refresh() dreamer = await dreamer.refresh()
edited = 2
# Now compare the existing data against the provided data and update if needed # Now compare the existing data against the provided data and update if needed
if name != dreamer.name: if name != dreamer.name:
@@ -159,6 +163,7 @@ class User:
q.set(preferences=prefs) q.set(preferences=prefs)
await q await q
dreamer = await dreamer.refresh() dreamer = await dreamer.refresh()
edited = 2
else: else:
# Create from scratch # Create from scratch
logger.info(f"Creating Dreamer for {twitch_id=} with {kwargs}") logger.info(f"Creating Dreamer for {twitch_id=} with {kwargs}")
@@ -176,8 +181,16 @@ class User:
) )
dreamer = await data.Dreamer.fetch(user_profile.profileid) dreamer = await data.Dreamer.fetch(user_profile.profileid)
assert dreamer is not None assert dreamer is not None
edited = 1
return cls(app, dreamer) self = cls(app, dreamer)
if edited == 1:
args = await self.event_log_args(title=f"User #{dreamer.user_id} created!")
await event_log(**args.send_args)
elif edited == 2:
args = await self.event_log_args(title=f"User #{dreamer.user_id} updated!")
await event_log(**args.send_args)
return self
async def edit(self, **kwargs: Unpack[UserEditParams]): async def edit(self, **kwargs: Unpack[UserEditParams]):
data = self.data data = self.data
@@ -193,9 +206,15 @@ class User:
logger.info(f"Updating dreamer {self.row=} with {kwargs}") logger.info(f"Updating dreamer {self.row=} with {kwargs}")
await prefs.update(**update_args) await prefs.update(**update_args)
args = await self.event_log_args(title=f"User #{self.row.user_id} updated!")
await event_log(**args.send_args)
async def delete(self) -> UserDetailsPayload: async def delete(self) -> UserDetailsPayload:
payload = await self.prepare(details=True) payload = await self.prepare(details=True)
await self.row.delete() # This will cascade to all other data the user has
await self.profile_data.UserProfileRow.table.delete_where(profileid=self.row.user_id)
# Make sure we take the user out of cache
await self.row.refresh()
return payload return payload
async def get_wallet(self): async def get_wallet(self):
@@ -222,6 +241,28 @@ class User:
async def get_inventory(self): async def get_inventory(self):
return [] return []
async def event_log_args(self, **kwargs) -> MessageArgs:
desc = '\n'.join(await self.tabulate())
embed = discord.Embed(description=desc, timestamp=self.row.created_at, **kwargs)
embed.set_footer(text='Created At')
# TODO: We could add wallet, specimen, and inventory info here too
return MessageArgs(embed=embed)
async def tabulate(self):
"""
Present the User as a discord-readable table.
"""
table = {
'user_id': f"`{self.row.user_id}`",
'twitch_id': f"`{self.row.twitch_id}`" if self.row.twitch_id else 'No Twitch linked',
'name': f"`{self.row.name}`",
'preferences': f"`{self.row.preferences}`",
'created_at': discord.utils.format_dt(self.row.created_at, 'F'),
}
return tabulate(*table.items())
@overload @overload
async def prepare(self, details: Literal[True]=True) -> UserDetailsPayload: async def prepare(self, details: Literal[True]=True) -> UserDetailsPayload:
... ...