44 Commits

Author SHA1 Message Date
1e7a5c9b8a fix(tasklist): Add profile migrator. 2024-12-03 08:43:03 +10:00
592017ba5e feat (tasklist): Add nowlist and plan tables. 2024-12-03 08:19:44 +10:00
49a8cefeef feat (tasklist): Migration to new profile tables. 2024-11-30 15:48:15 +10:00
d4870740a2 Merge branch 'feat-nowlist-profiles' into staging 2024-10-08 11:18:53 +10:00
8991b1a641 feat(nowlist): Implement edit cmd. 2024-10-07 12:19:03 +10:00
79645177bd feat(timer): Add new current task. 2024-10-07 12:18:43 +10:00
9b3b7265d3 fix(nowlist): Added reload after migration. 2024-10-07 11:57:40 +10:00
3c0d527501 feat(nowlist): Task profile migration. 2024-10-07 11:48:45 +10:00
997804c6bf feat(nows): Moved tasklist to profiles. 2024-10-07 01:07:56 +10:00
2cdd084bbe Merge branch 'feat-profiles' into staging 2024-10-06 21:40:56 +10:00
72d52b6014 feat(profiles): Add community profiles. 2024-10-06 21:38:09 +10:00
92fee23afa feat(profiles): Add profile base and users. 2024-10-06 15:43:49 +10:00
83a63e8a6e Merge branch 'staging' into feat-profiles 2024-10-05 08:01:48 +10:00
63152f3475 routine: Use ssh url for voicefix submodule. 2024-10-05 07:50:43 +10:00
81e25e7efc feat(vcroles): Add voice autoroles. 2024-10-05 04:07:46 +10:00
ce07f7ae73 Merge branch 'feat-target-seeker' into staging 2024-10-05 03:17:58 +10:00
d158aed257 feat: Add target seeker. 2024-10-05 03:17:32 +10:00
47a52d9600 routine: Update voicefix pointer. 2024-10-05 01:42:15 +10:00
fc459ac0dd fix(counters): Fix counter response period user. 2024-09-30 19:02:06 +10:00
45b57b4eca (counters): Remove outdated comment. 2024-09-28 16:58:40 +10:00
22b99717db Merge branch 'feat-counter-refactor' into staging 2024-09-28 15:42:13 +10:00
2810365588 feat(counters): Dynamic counter aliases. 2024-09-28 15:39:24 +10:00
e9946a9814 feat(cog): Attach single twitchIO command to cog. 2024-09-28 15:38:13 +10:00
8f6fdf3381 feat(counters): Add details and initial refactor. 2024-09-27 00:36:12 +10:00
9d0d19d046 (profiles): Start internal API. 2024-09-26 21:22:42 +10:00
a7eb8d0f09 Merge branch 'feat-auth' into feat-profiles 2024-09-26 01:49:55 +10:00
9c738ecb91 feat(twitch): Add basic user authentication flow. 2024-09-26 01:48:24 +10:00
9c9107bf9d fix(timers): Remove user from last_seen on leave.
Fixes an issue where user inactivity was inaccurately tracked on rejoin.
2024-09-26 01:46:39 +10:00
caa907b6d9 feat(twitch): Add UserAuthFlow for user auth. 2024-09-23 15:56:51 +10:00
44d6d77494 feat(twitch): Add authentication server. 2024-09-23 15:56:18 +10:00
f2c449d2e0 feat (timer): Streamtimer channel editing 2024-09-15 15:15:54 +10:00
53366c0333 (voice): Adjust now responses. 2024-09-15 14:10:29 +10:00
66f7680482 feat (voice): Loosen now cmd restrictions. 2024-09-15 12:24:40 +10:00
37f25f10ef Merge branch 'timerlayout' into staging 2024-09-15 11:53:39 +10:00
87488eaf99 feat (timer): Add support for new timer layout. 2024-09-15 11:52:56 +10:00
7d327a5e2f Merge branch 'staging' into feat-auth 2024-09-08 16:58:43 +10:00
6d59ab0a34 Merge branch 'feat-cog-refactor' into staging 2024-09-08 16:57:15 +10:00
970661fe05 (tags): Migrated to merged LionCog. 2024-09-08 16:55:55 +10:00
75ab3d58cb (shoutouts): Migrate to merged LionCog. 2024-09-08 16:47:59 +10:00
85c7aeb3b6 (nowdoing): Migrate to merged LionCog. 2024-09-08 16:35:37 +10:00
99bb1958a8 (LionCog): Split check types. 2024-09-08 16:17:34 +10:00
41f755795f feat: Start twitch user auth module. 2024-09-06 10:59:47 +10:00
bc073363b9 feat: Start merged profiles and communities. 2024-09-06 10:59:13 +10:00
b7e4acfee2 Merge disc and twitchio Cogs. 2024-09-06 10:57:07 +10:00
48 changed files with 2434 additions and 430 deletions

6
.gitmodules vendored
View File

@@ -1,12 +1,12 @@
[submodule "bot/gui"]
path = src/gui
url = https://github.com/StudyLions/StudyLion-Plugin-GUI.git
url = git@github.com:Intery/CafeHelper-GUI.git
[submodule "skins"]
path = skins
url = https://github.com/Intery/pillow-skins.git
url = git@github.com:Intery/CafeHelper-Skins.git
[submodule "src/modules/voicefix"]
path = src/modules/voicefix
url = https://github.com/Intery/StudyLion-voicefix.git
url = git@github.com:Intery/StudyLion-voicefix.git
[submodule "src/modules/streamalerts"]
path = src/modules/streamalerts
url = https://github.com/Intery/StudyLion-streamalerts.git

View File

@@ -287,13 +287,14 @@ CREATE TABLE tasklist(
deleted_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ,
last_updated_at TIMESTAMPTZ
last_updated_at TIMESTAMPTZ,
duration INTEGER
);
CREATE INDEX tasklist_users ON tasklist (userid);
ALTER TABLE tasklist
ADD CONSTRAINT fk_tasklist_users
FOREIGN KEY (userid)
REFERENCES user_config (userid)
REFERENCES user_profiles (profileid)
ON DELETE CASCADE
NOT VALID;
ALTER TABLE tasklist
@@ -317,6 +318,20 @@ CREATE TABLE tasklist_reward_history(
reward_count INTEGER
);
CREATE INDEX tasklist_reward_history_users ON tasklist_reward_history (userid, reward_time);
CREATE TABLE tasklist_current(
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
started_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE tasklist_planner(
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
sortkey INTEGER
);
-- }}}
-- Reminder data {{{
@@ -1454,6 +1469,7 @@ CREATE TABLE shoutouts(
CREATE TABLE counters(
counterid SERIAL PRIMARY KEY,
name TEXT NOT NULL,
category TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX counters_name ON counters (name);
@@ -1464,6 +1480,7 @@ CREATE TABLE counter_log(
userid INTEGER NOT NULL,
value INTEGER NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
details TEXT,
context_str TEXT
);
CREATE INDEX counter_log_counterid ON counter_log (counterid);
@@ -1484,6 +1501,98 @@ CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name
-- }}}
-- Voice Roles {{{
CREATE TABLE voice_roles(
voice_role_id SERIAL PRIMARY KEY,
channelid BIGINT NOT NULL,
roleid BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX voice_role_channels on voice_roles (channelid);
-- }}}
-- User and Community Profiles {{{
DROP TABLE IF EXISTS community_members;
DROP TABLE IF EXISTS communities_twitch;
DROP TABLE IF EXISTS communities_discord;
DROP TABLE IF EXISTS communities;
DROP TABLE IF EXISTS profiles_twitch;
DROP TABLE IF EXISTS profiles_discord;
DROP TABLE IF EXISTS user_profiles;
CREATE TABLE user_profiles(
profileid SERIAL PRIMARY KEY,
nickname TEXT,
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE profiles_discord(
linkid SERIAL PRIMARY KEY,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
userid BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid);
CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid);
CREATE TABLE profiles_twitch(
linkid SERIAL PRIMARY KEY,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
userid TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid);
CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid);
CREATE TABLE communities(
communityid SERIAL PRIMARY KEY,
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE communities_discord(
guildid BIGINT PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX communities_discord_communityid ON communities_discord (communityid);
CREATE TABLE communities_twitch(
channelid TEXT PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX communities_twitch_communityid ON communities_twitch (communityid);
CREATE TABLE community_members(
memberid SERIAL PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
-- }}}
-- Twitch User Auth {{{
CREATE TABLE twitch_user_auth(
userid TEXT PRIMARY KEY,
access_token TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
refresh_token TEXT NOT NULL,
obtained_at TIMESTAMPTZ
);
CREATE TABLE twitch_user_scopes(
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
scope TEXT
);
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
-- }}}
-- Analytics Data {{{

2
skins

Submodule skins updated: d3d6a28bc9...686857321e

View File

@@ -80,6 +80,14 @@ async def main():
websockets.serve(sockets.root_handler, '', conf.wserver['port'])
)
crocbot = CrocBot(
config=conf,
data=db,
prefix='!',
initial_channels=conf.croccy.getlist('initial_channels'),
token=conf.croccy['token'],
)
lionbot = await stack.enter_async_context(
LionBot(
command_prefix='!',
@@ -90,6 +98,7 @@ async def main():
config=conf,
initial_extensions=[
'utils', 'core', 'analytics',
'twitch',
'modules',
'babel',
'tracking.voice', 'tracking.text',
@@ -104,26 +113,15 @@ async def main():
translator=translator,
chunk_guilds_at_startup=False,
system_monitor=system_monitor,
crocbot=crocbot,
)
)
crocbot = CrocBot(
config=conf,
data=db,
prefix='!',
initial_channels=conf.croccy.getlist('initial_channels'),
token=conf.croccy['token'],
lionbot=lionbot
)
lionbot.crocbot = crocbot
crocbot.load_module('modules')
crocstart = asyncio.create_task(start_croccy(crocbot))
lionstart = asyncio.create_task(start_lion(lionbot))
await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED)
crocstart.cancel()
lionstart.cancel()
# crocstart.cancel()
# lionstart.cancel()
async def start_lion(lionbot):
ctx_bot.set(lionbot)

Submodule src/gui updated: c1bcb05c25...62d2484914

View File

@@ -1,19 +1,18 @@
from collections import defaultdict
from typing import TYPE_CHECKING
import logging
import twitchio
from twitchio.ext import commands
from twitchio.ext import pubsub
from twitchio.ext.commands.core import itertools
from data import Database
from .config import Conf
if TYPE_CHECKING:
from .LionBot import LionBot
logger = logging.getLogger(__name__)
@@ -21,12 +20,57 @@ class CrocBot(commands.Bot):
def __init__(self, *args,
config: Conf,
data: Database,
lionbot: 'LionBot', **kwargs):
**kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.data = data
self.pubsub = pubsub.PubSubPool(self)
self.lionbot = lionbot
self._member_cache = defaultdict(dict)
async def event_ready(self):
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
async def event_join(self, channel: twitchio.Channel, user: twitchio.User):
self._member_cache[channel.name][user.name] = user
async def event_message(self, message: twitchio.Message):
if message.channel and message.author:
self._member_cache[message.channel.name][message.author.name] = message.author
await self.handle_commands(message)
async def seek_user(self, userstr: str, matching=True, fuzzy=True):
if userstr.startswith('@'):
matching = False
userstr = userstr.strip('@ ')
result = None
if matching and len(userstr) >= 3:
lowered = userstr.lower()
full_matches = []
for user in itertools.chain(*(cmems.values() for cmems in self._member_cache.values())):
matchstr = user.name.lower()
print(matchstr)
if matchstr.startswith(lowered):
result = user
break
if lowered in matchstr:
full_matches.append(user)
if result is None and full_matches:
result = full_matches[0]
print(result)
if result is None:
lookup = userstr
elif result.id is None:
lookup = result.name
else:
lookup = None
if lookup:
found = await self.fetch_users(names=[lookup])
if found:
result = found[0]
# No matches found
return result

View File

@@ -24,8 +24,10 @@ from .errors import HandledException, SafeCancellation
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus
if TYPE_CHECKING:
from meta.CrocBot import CrocBot
from core.cog import CoreCog
from core.config import ConfigCog
from twitch.cog import TwitchAuthCog
from tracking.voice.cog import VoiceTrackerCog
from tracking.text.cog import TextTrackerCog
from modules.config.cog import GuildConfigCog
@@ -48,6 +50,7 @@ if TYPE_CHECKING:
from modules.topgg.cog import TopggCog
from modules.user_config.cog import UserConfigCog
from modules.video_channels.cog import VideoCog
from modules.profiles.cog import ProfileCog
logger = logging.getLogger(__name__)
@@ -58,6 +61,7 @@ class LionBot(Bot):
initial_extensions: List[str], web_client: ClientSession, app_ipc,
testing_guilds: List[int] = [],
system_monitor: Optional[SystemMonitor] = None,
crocbot: Optional['CrocBot'] = None,
**kwargs
):
kwargs.setdefault('tree_cls', LionTree)
@@ -73,6 +77,8 @@ class LionBot(Bot):
self.app_ipc = app_ipc
self.translator = translator
self.crocbot = crocbot
self.system_monitor = system_monitor or SystemMonitor()
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
self.system_monitor.add_component(self.monitor)
@@ -86,6 +92,10 @@ class LionBot(Bot):
def core(self):
return self.get_cog('CoreCog')
@property
def profiles(self):
return self.get_cog('ProfileCog')
async def _handle_global_dispatch(self, event_name: str, *args, **kwargs):
self.dispatch(event_name, *args, **kwargs)
@@ -138,6 +148,10 @@ class LionBot(Bot):
# To make the type checker happy about fetching cogs by name
# TODO: Move this to stubs at some point
@overload
def get_cog(self, name: Literal['ProfileCog']) -> 'ProfileCog':
...
@overload
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
...
@@ -150,6 +164,10 @@ class LionBot(Bot):
def get_cog(self, name: Literal['VoiceTrackerCog']) -> 'VoiceTrackerCog':
...
@overload
def get_cog(self, name: Literal['TwitchAuthCog']) -> 'TwitchAuthCog':
...
@overload
def get_cog(self, name: Literal['TextTrackerCog']) -> 'TextTrackerCog':
...

View File

@@ -1,23 +1,37 @@
from typing import Any
from functools import partial
from typing import Any, Callable, Optional
from discord.ext.commands import Cog
from discord.ext import commands as cmds
from twitchio.ext import commands
from twitchio.ext.commands import Command, Bot
from twitchio.ext.commands.meta import CogEvent
class LionCog(Cog):
# A set of other cogs that this cog depends on
depends_on: set['LionCog'] = set()
_placeholder_groups_: set[str]
_twitch_cmds_: dict[str, Command]
_twitch_events_: dict[str, CogEvent]
_twitch_events_loaded_: set[Callable]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._placeholder_groups_ = set()
cls._twitch_cmds_ = {}
cls._twitch_events_ = {}
cls._twitch_events_loaded_ = set()
for base in reversed(cls.__mro__):
for elem, value in base.__dict__.items():
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
cls._placeholder_groups_.add(value.name)
elif isinstance(value, Command):
cls._twitch_cmds_[value.name] = value
elif isinstance(value, CogEvent):
cls._twitch_events_[value.name] = value
def __new__(cls, *args: Any, **kwargs: Any):
# Patch to ensure no placeholder groups are in the command list
@@ -34,6 +48,72 @@ class LionCog(Cog):
return await super()._inject(bot, *args, *kwargs)
def add_twitch_command(self, bot: Bot, command: Command):
"""
Dynamically register a command with the given bot.
The command will be deregistered on cog unload.
"""
# Remove any conflicting commands
if cmd := bot.get_command(command.name):
bot.remove_command(cmd.name)
self._twitch_cmds_.pop(command.name, None)
try:
self._twitch_cmds_[command.name] = command
command._instance = self
command.cog = self
bot.add_command(command)
except Exception:
# Ensure the command doesn't die in the internal command cache
self._twitch_cmds_.pop(command.name, None)
raise
def _load_twitch_methods(self, bot: Bot):
for name, command in self._twitch_cmds_.items():
command._instance = self
command.cog = self
bot.add_command(command)
for name, event in self._twitch_events_.items():
callback = partial(event, self)
self._twitch_events_loaded_.add(callback)
bot.add_event(callback=callback, name=name)
def _unload_twitch_methods(self, bot: Bot):
for name in self._twitch_cmds_:
bot.remove_command(name)
for callback in self._twitch_events_loaded_:
bot.remove_event(callback=callback)
self._twitch_events_loaded_.clear()
@classmethod
def twitch_event(cls, event: Optional[str] = None):
def decorator(func) -> CogEvent:
event_name = event or func.__name__
return CogEvent(name=event_name, func=func, module=cls.__module__)
return decorator
async def cog_check(self, ctx): # type: ignore
"""
TwitchIO assumes cog_check is a coroutine,
so here we narrow the check to only a coroutine.
The ctx maybe either be a twitch command context or a dpy context.
"""
if isinstance(ctx, cmds.Context):
return await self.cog_check_discord(ctx)
if isinstance(ctx, commands.Context):
return await self.cog_check_twitch(ctx)
async def cog_check_discord(self, ctx: cmds.Context):
return True
async def cog_check_twitch(self, ctx: commands.Context):
return True
@classmethod
def placeholder_group(cls, group: cmds.HybridGroup):
group._placeholder_group_ = True

View File

@@ -13,6 +13,8 @@ if TYPE_CHECKING:
from core.lion_member import LionMember
from core.lion_user import LionUser
from core.lion_guild import LionGuild
from modules.profiles.profile import UserProfile
from modules.profiles.community import Community
logger = logging.getLogger(__name__)
@@ -54,6 +56,8 @@ class LionContext(Context['LionBot']):
lguild: 'LionGuild'
lmember: 'LionMember'
alion: 'LionUser | LionMember'
profile: 'UserProfile'
community: 'Community'
def __repr__(self):
parts = {}

View File

@@ -2,6 +2,7 @@ this_package = 'modules'
active_discord = [
'.sysadmin',
'.profiles',
'.config',
'.user_config',
'.skins',
@@ -26,20 +27,13 @@ active_discord = [
'.premium',
'.streamalerts',
'.test',
]
active_twitch = [
'.counters',
'.nowdoing',
'.shoutouts',
'.counters',
'.tagstrings',
'.voiceroles',
]
def prepare(bot):
for ext in active_twitch:
bot.load_module(this_package + ext)
async def setup(bot):
for ext in active_discord:
await bot.load_extension(ext, package=this_package)

View File

@@ -4,10 +4,5 @@ logger = logging.getLogger(__name__)
from .cog import CounterCog
def prepare(bot):
bot.add_cog(CounterCog(bot))
async def setup(bot):
from .lion_cog import CounterCog
await bot.add_cog(CounterCog(bot))

View File

@@ -3,11 +3,14 @@ from enum import Enum
from typing import Optional
from datetime import timedelta
import discord
from discord.ext import commands as cmds
import twitchio
from twitchio.ext import commands
from data.queries import ORDER
from meta import CrocBot
from meta import LionCog, LionBot, CrocBot
from utils.lib import utc_now
from . import logger
from .data import CounterData
@@ -22,10 +25,80 @@ class PERIOD(Enum):
YEAR = ('this year', 'y', 'year', 'yearly')
class CounterCog(commands.Cog):
def __init__(self, bot: CrocBot):
def counter_cmd_factory(
counter: str,
response: str,
default_period: Optional[PERIOD] = PERIOD.STREAM,
context: Optional[str] = None
):
context = context or f"cmd: {counter}"
async def counter_cmd(cog, ctx: commands.Context, *, args: Optional[str] = None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await cog.parse_period(channelid, '', default=default_period)
args = (args or '').strip(" 󠀀 ")
splits = args.split(maxsplit=1)
splits = [split.strip() for split in splits if split]
details = None
amount = 1
if splits:
if splits[0].isdigit() or (splits[0].startswith('-') and splits[0][1:].isdigit()):
amount = int(splits[0])
splits = splits[1:]
if splits:
details = ' '.join(splits)
await cog.add_to_counter(
counter, userid, amount,
context=context,
details=details
)
lb = await cog.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(
response.format(
total=total,
period=period,
period_name=period.value[0],
detailsorname=details or counter,
user_total=user_total,
)
)
async def lb_cmd(cog, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await cog.formatted_lb(counter, args, int(user.id)))
async def undo_cmd(cog, ctx: commands.Context):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
_counter = await cog.fetch_counter(counter)
query = cog.data.CounterEntry.fetch_where(
counterid=_counter.counterid,
userid=userid,
)
query.order_by('created_at', direction=ORDER.DESC)
query.limit(1)
results = await query
if not results:
await ctx.reply("Nothing to delete!")
else:
row = results[0]
await row.delete()
await ctx.reply("Undo successful!")
return (counter_cmd, lb_cmd, undo_cmd)
class CounterCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.data.load_registry(CounterData())
self.crocbot: CrocBot = bot.crocbot
self.data = bot.db.load_registry(CounterData())
self.loaded = asyncio.Event()
@@ -33,9 +106,42 @@ class CounterCog(commands.Cog):
self.counters = {}
async def cog_load(self):
self._load_twitch_methods(self.crocbot)
await self.load_counter_commands()
await self.data.init()
await self.load_counters()
self.loaded.set()
async def cog_unload(self):
self._unload_twitch_methods(self.crocbot)
async def load_counter_commands(self):
rows = await self.data.CounterCommand.fetch_where()
for row in rows:
counter = await self.data.Counter.fetch(row.counterid)
counter_cb, lb_cb, undo_cb = counter_cmd_factory(
counter.name,
row.response
)
cmds = []
main_cmd = commands.command(name=row.name)(counter_cb)
cmds.append(main_cmd)
if row.lbname:
lb_cmd = commands.command(name=row.lbname)(lb_cb)
cmds.append(lb_cmd)
if row.undoname:
undo_cmd = commands.command(name=row.undoname)(undo_cb)
cmds.append(undo_cmd)
for cmd in cmds:
self.add_twitch_command(self.crocbot, cmd)
logger.info(f"(Re)Loaded {len(rows)} counter commands!")
async def cog_check(self, ctx):
return True
async def load_counters(self):
"""
Initialise counter name cache.
@@ -46,18 +152,6 @@ class CounterCog(commands.Cog):
f"Loaded {len(self.counters)} counters."
)
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
@commands.Cog.event('event_ready') # type: ignore
async def on_ready(self):
await self.ensure_loaded()
async def cog_check(self, ctx):
await self.ensure_loaded()
return True
# Counters API
async def fetch_counter(self, counter: str) -> CounterData.Counter:
@@ -79,13 +173,19 @@ class CounterCog(commands.Cog):
if row:
await self.data.CounterEntry.table.delete_where(counterid=row.counterid)
async def add_to_counter(self, counter: str, userid: int, value: int, context: Optional[str]=None):
async def add_to_counter(
self,
counter: str, userid: int, value: int,
context: Optional[str]=None,
details: Optional[str]=None,
):
row = await self.fetch_counter(counter)
return await self.data.CounterEntry.create(
counterid=row.counterid,
userid=userid,
value=value,
context_str=context
context_str=context,
details=details
)
async def leaderboard(self, counter: str, start_time=None):
@@ -154,8 +254,43 @@ class CounterCog(commands.Cog):
elif subcmd == 'clear':
await self.reset_counter(name)
await ctx.reply(f"'{name}' counter reset.")
elif subcmd == 'alias':
splits = args.split(maxsplit=3) if args else []
counter = await self.fetch_counter(name)
rows = await self.data.CounterCommand.fetch_where(counterid=counter.counterid)
existing = rows[0] if rows else None
if existing and not args:
# Show current alias
await ctx.reply(
f"Counter '{name}' aliases: '!{existing.name}' to add to counter; "
f"'!{existing.lbname}' to view counter leaderboard; "
f"'!{existing.undoname}' to undo (your) last addition."
)
elif len(splits) < 4:
# Show usage
await ctx.reply(
"USAGE: !counter <name> alias <cmdname> <lbname> <undoname> <response> -- "
"Response accepts keywords {total}, {period}, {period_name}, {detailsorname}, {user_total}."
)
else:
# Create new alias
cmdname, lbname, undoname, response = splits
# Remove any existing alias
await self.data.CounterCommand.table.delete_where(name=cmdname)
alias = await self.data.CounterCommand.create(
name=cmdname,
counterid=counter.counterid,
lbname=lbname, undoname=undoname, response=response
)
await self.load_counter_commands()
await ctx.reply(
f"Alias created for counter '{name}': '!{alias.name}' to add to counter; "
f"'!{alias.lbname}' to view counter leaderboard; "
f"'!{alias.undoname}' to undo (your) last addition."
)
else:
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear'.")
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear', 'alias'.")
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
if periodstr:
@@ -171,7 +306,7 @@ class CounterCog(commands.Cog):
if period is PERIOD.ALL:
start_time = None
elif period is PERIOD.STREAM:
streams = await self.bot.fetch_streams(user_ids=[userid])
streams = await self.crocbot.fetch_streams(user_ids=[userid])
if streams:
stream = streams[0]
start_time = stream.started_at
@@ -199,7 +334,7 @@ class CounterCog(commands.Cog):
lb = await self.leaderboard(counter, start_time=start_time)
if lb:
userids = list(lb.keys())
users = await self.bot.fetch_users(ids=userids)
users = await self.crocbot.fetch_users(ids=userids)
name_map = {user.id: user.display_name for user in users}
parts = []
for userid, total in lb.items():
@@ -210,90 +345,3 @@ class CounterCog(commands.Cog):
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}"
else:
return f"{counter} {period.value[-1]} leaderboard is empty!"
# Misc actual counter commands
# TODO: Factor this out to a different module...
@commands.command()
async def tea(self, ctx: commands.Context, *, args: Optional[str]=None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await self.parse_period(channelid, '')
counter = 'tea'
await self.add_to_counter(
counter,
userid,
1,
context='cmd: tea'
)
lb = await self.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(f"Enjoy your tea! We have had {total} cups of tea {period.value[0]}.")
@commands.command()
async def tealb(self, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await self.formatted_lb('tea', args, int(user.id)))
@commands.command()
async def coffee(self, ctx: commands.Context, *, args: Optional[str]=None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await self.parse_period(channelid, '')
counter = 'coffee'
await self.add_to_counter(
counter,
userid,
1,
context='cmd: coffee'
)
lb = await self.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(f"Enjoy your coffee! We have had {total} cups of coffee {period.value[0]}.")
@commands.command()
async def coffeelb(self, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await self.formatted_lb('coffee', args, int(user.id)))
@commands.command()
async def water(self, ctx: commands.Context, *, args: Optional[str]=None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await self.parse_period(channelid, '')
counter = 'water'
await self.add_to_counter(
counter,
userid,
1,
context='cmd: water'
)
lb = await self.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(f"Good job hydrating! We have had {total} cups of water {period.value[0]}.")
@commands.command()
async def waterlb(self, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
@commands.command()
async def reload(self, ctx: commands.Context, *, args: str = ''):
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
return
if not args:
await ctx.reply("Full reload not implemented yet.")
else:
try:
self.bot.reload_module(args)
except Exception:
logger.exception("Failed to reload")
await ctx.reply("Failed to reload module! Check console~")
else:
await ctx.reply("Reloaded!")

View File

@@ -10,7 +10,8 @@ class CounterData(Registry):
CREATE TABLE counters(
counterid SERIAL PRIMARY KEY,
name TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
category TEXT
);
CREATE UNIQUE INDEX counters_name ON counters (name);
"""
@@ -19,6 +20,7 @@ class CounterData(Registry):
counterid = Integer(primary=True)
name = String()
category = String()
created_at = Timestamp()
class CounterEntry(RowModel):
@@ -31,7 +33,8 @@ class CounterData(Registry):
userid INTEGER NOT NULL,
value INTEGER NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
context_str TEXT
context_str TEXT,
details TEXT
);
CREATE INDEX counter_log_counterid ON counter_log (counterid);
"""
@@ -44,5 +47,28 @@ class CounterData(Registry):
value = Integer()
created_at = Timestamp()
context_str = String()
details = String()
class CounterCommand(RowModel):
"""
Schema
------
CREATE TABLE counter_commands(
name TEXT PRIMARY KEY,
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
lbname TEXT,
undoname TEXT,
response TEXT NOT NULL
);
"""
# NOTE: This table will be replaced by aliases soon anyway
# So no need to worry about integrity or future-proofing
_tablename_ = 'counter_commands'
_cache_ = {}
name = String(primary=True)
counterid = Integer()
lbname = String()
undoname = String()
response = String()

View File

@@ -1,23 +0,0 @@
import asyncio
from typing import Optional
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionCog, LionContext
from meta.errors import UserInputError
from meta.logger import log_wrap
from utils.lib import utc_now
from data.conditions import NULL
from . import logger
from .data import CounterData
class CounterCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.counter_cog = bot.crocbot.get_cog('CounterCog')

View File

@@ -0,0 +1,9 @@
ALTER TABLE counters ADD COLUMN category TEXT;
ALTER TABLE counter_log ADD COLUMN details TEXT;
CREATE TABLE counter_commands(
name TEXT PRIMARY KEY,
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
lbname TEXT,
undoname TEXT,
response TEXT NOT NULL
);

View File

@@ -4,6 +4,5 @@ logger = logging.getLogger(__name__)
from .cog import NowDoingCog
def prepare(bot):
logger.info("Preparing the nowdoing module.")
bot.add_cog(NowDoingCog(bot))
async def setup(bot):
await bot.add_cog(NowDoingCog(bot))

View File

@@ -4,16 +4,21 @@ import json
import os
from typing import Optional
from attr import dataclass
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
import twitchio
from twitchio.ext import commands
from meta import CrocBot
from meta import CrocBot, LionCog, LionContext, LionBot
from meta.sockets import Channel, register_channel
from utils.lib import strfdelta, utc_now
from . import logger
from .data import NowListData
from modules.profiles.profile import UserProfile
class NowDoingChannel(Channel):
name = 'NowList'
@@ -24,19 +29,7 @@ class NowDoingChannel(Channel):
async def on_connection(self, websocket, event):
await super().on_connection(websocket, event)
for task in self.cog.tasks.values():
await self.send_set(*self.task_args(task), websocket=websocket)
async def send_test_set(self):
tasks = [
(0, 'Tester0', "Testing Tasklist", True),
(1, 'Tester1', "Getting Confused", False),
(2, "Tester2", "Generating Bugs", True),
(3, "Tester3", "Fixing Bugs", False),
(4, "Tester4", "Pushing the red button", False),
]
for task in tasks:
await self.send_set(*task)
await self.reload_tasklist(websocket=websocket)
def task_args(self, task: NowListData.Task):
return (
@@ -47,6 +40,14 @@ class NowDoingChannel(Channel):
task.done_at.isoformat() if task.done_at else None,
)
async def reload_tasklist(self, websocket=None):
"""
Clear tasklist and re-send current tasks.
"""
await self.send_clear(websocket=websocket)
for task in self.cog.tasks.values():
await self.send_set(*self.task_args(task), websocket=websocket)
async def send_set(self, userid, name, task, start_at, end_at, websocket=None):
await self.send_event({
'type': "DO",
@@ -60,28 +61,29 @@ class NowDoingChannel(Channel):
}
}, websocket=websocket)
async def send_del(self, userid):
async def send_del(self, userid, websocket=None):
await self.send_event({
'type': "DO",
'method': "delTask",
'args': {
'userid': userid,
}
})
}, websocket=websocket)
async def send_clear(self):
async def send_clear(self, websocket=None):
await self.send_event({
'type': "DO",
'method': "clearTasks",
'args': {
}
})
}, websocket=websocket)
class NowDoingCog(commands.Cog):
def __init__(self, bot: CrocBot):
class NowDoingCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.data.load_registry(NowListData())
self.crocbot: CrocBot = bot.crocbot
self.data = bot.db.load_registry(NowListData())
self.channel = NowDoingChannel(self)
register_channel(self.channel.name, self.channel)
@@ -92,23 +94,86 @@ class NowDoingCog(commands.Cog):
async def cog_load(self):
await self.data.init()
await self.load_tasks()
self.bot.get_cog('ProfileCog').add_profile_migrator(self.migrate_profiles, name='task-migrator')
self._load_twitch_methods(self.crocbot)
self.loaded.set()
async def ensure_loaded(self):
"""
Hack because lib devs decided to remove async cog loading.
"""
if not self.loaded.is_set():
await self.cog_load()
async def cog_unload(self):
self.loaded.clear()
self.tasks.clear()
if profiles := self.bot.get_cog('ProfileCog'):
profiles.del_profile_migrator('task-migrator')
self._unload_twitch_methods(self.crocbot)
@commands.Cog.event('event_ready') # type: ignore
async def on_ready(self):
await self.ensure_loaded()
async def migrate_profiles(self, source_profile: UserProfile, target_profile: UserProfile):
"""
Move current source task to target profile if there's room for it, otherwise annihilate
"""
await self.load_tasks()
source_task = self.tasks.pop(source_profile.profileid, None)
results = ["(Tasklist)"]
if source_task:
target_task = self.tasks.get(target_profile.profileid, None)
if target_task and (target_task.is_done or target_task.started_at < source_task.started_at):
# If target is done, remove it so we can overwrite
results.append("Removed older task from target profile.")
await target_task.delete()
target_task = None
if not target_task:
# Update source task with new profile id
await source_task.update(userid=target_profile.profileid)
target_task = source_task
await self.channel.send_set(*self.channel.task_args(target_task))
results.append("Migrated 1 currently running task from source profile.")
else:
# If there is a target task we can't overwrite, just delete the source task
await source_task.delete()
results.append("Ignoring and removing older task from source profile.")
self.tasks.pop(source_profile.profileid, None)
await self.channel.send_del(source_profile.profileid)
else:
results.append("No running task in source profile, nothing to migrate!")
await self.load_tasks()
return ' '.join(results)
async def user_profile_migration(self):
"""
Manual single-use migration method from the old userid format to the new profileid format.
"""
await self.load_tasks()
for userid, task in self.tasks.items():
userid = int(userid)
if userid > 1000:
# Assume it is a twitch userid
profile = await UserProfile.fetch_from_twitchid(self.bot, userid)
if not profile:
# Create a new profile with this twitch user
users = await self.crocbot.fetch_users(ids=[userid])
if not users:
continue
user = users[0]
profile = await UserProfile.create_from_twitch(self.bot, user)
if not await self.data.Task.fetch(profile.profileid):
await task.update(userid=profile.profileid)
else:
await task.delete()
await self.load_tasks()
await self.channel.reload_tasklist()
async def cog_check(self, ctx):
await self.ensure_loaded()
if not self.loaded.is_set():
await ctx.reply("Tasklists are still loading! Please wait a moment~")
return False
return True
async def load_tasks(self):
@@ -123,24 +188,27 @@ class NowDoingCog(commands.Cog):
# await self.channel.send_test_set()
# await ctx.send(f"Hello {ctx.author.name}! This command does something, we aren't sure what yet.")
# await ctx.send(str(list(self.tasks.items())[0]))
await self.user_profile_migration()
await ctx.send(str(ctx.author.id))
await ctx.reply("Userid -> profile migration done.")
else:
await ctx.send(f"Hello {ctx.author.name}! I don't think you have permission to test that.")
@commands.command(aliases=['task', 'check'])
async def now(self, ctx: commands.Context, *, args: Optional[str] = None):
userid = int(ctx.author.id)
async def now(self, ctx: commands.Context | LionContext, profile: UserProfile, args: Optional[str] = None, edit=False):
args = args.strip() if args else None
userid = profile.profileid
if args:
existing = self.tasks.get(userid, None)
await self.data.Task.table.delete_where(userid=userid)
task = await self.data.Task.create(
userid=userid,
name=ctx.author.display_name,
task=args,
started_at=utc_now(),
started_at=existing.started_at if (existing and edit) else utc_now(),
)
self.tasks[task.userid] = task
await self.channel.send_set(*self.channel.task_args(task))
await ctx.send(f"Updated your current task, good luck!")
await ctx.send("Updated your current task, good luck!")
elif task := self.tasks.get(userid, None):
if task.is_done:
done_ago = strfdelta(utc_now() - task.done_at)
@@ -158,9 +226,38 @@ class NowDoingCog(commands.Cog):
"Show what you are currently working on with, e.g. !now Reading notes"
)
@commands.command(name='next')
async def nownext(self, ctx: commands.Context, *, args: Optional[str] = None):
userid = int(ctx.author.id)
@commands.command(
name='now',
aliases=['task', 'check']
)
async def twi_now(self, ctx: commands.Context, *, args: Optional[str] = None):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
await self.now(ctx, profile, args)
@cmds.hybrid_command(
name='now',
aliases=['task', 'check']
)
async def disc_now(self, ctx: LionContext, *, args: Optional[str] = None):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
await self.now(ctx, profile, args)
@commands.command(
name='edit',
)
async def twi_edit(self, ctx: commands.Context, *, args: Optional[str] = None):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
await self.now(ctx, profile, args, edit=True)
@cmds.hybrid_command(
name='edit',
)
async def disc_edit(self, ctx: LionContext, *, args: Optional[str] = None):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
await self.now(ctx, profile, args, edit=True)
async def nownext(self, ctx: commands.Context | LionContext, profile: UserProfile, args: Optional[str]):
userid = profile.profileid
task = self.tasks.get(userid, None)
if args:
if task:
@@ -181,7 +278,7 @@ class NowDoingCog(commands.Cog):
)
self.tasks[task.userid] = task
await self.channel.send_set(*self.channel.task_args(task))
await ctx.send(f"Next task set, good luck!" + ' ' + prefix)
await ctx.send("Next task set, good luck!" + ' ' + prefix)
elif task:
if task.is_done:
done_ago = strfdelta(utc_now() - task.done_at)
@@ -199,9 +296,22 @@ class NowDoingCog(commands.Cog):
"Show what you are currently working on with, e.g. !now Reading notes"
)
@commands.command()
async def done(self, ctx: commands.Context):
userid = int(ctx.author.id)
@commands.command(
name='next',
)
async def twi_next(self, ctx: commands.Context, *, args: Optional[str] = None):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
await self.nownext(ctx, profile, args)
@cmds.hybrid_command(
name='next',
)
async def disc_next(self, ctx: LionContext, *, args: Optional[str] = None):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
await self.nownext(ctx, profile, args)
async def done(self, ctx: commands.Context | LionContext, profile: UserProfile):
userid = profile.profileid
if task := self.tasks.get(userid, None):
if task.is_done:
await ctx.send(
@@ -221,9 +331,36 @@ class NowDoingCog(commands.Cog):
"Show what you are currently working on with, e.g. !now Reading notes"
)
@commands.command()
async def clear(self, ctx: commands.Context):
userid = int(ctx.author.id)
@commands.command(
name='done',
)
async def twi_done(self, ctx: commands.Context):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
await self.done(ctx, profile)
@cmds.hybrid_command(
name='done',
)
async def disc_done(self, ctx: LionContext):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
await self.done(ctx, profile)
@commands.command(
name='clear',
)
async def twi_clear(self, ctx: commands.Context):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
await self.clear(ctx, profile)
@cmds.hybrid_command(
name='clear',
)
async def disc_clear(self, ctx: LionContext):
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
await self.clear(ctx, profile)
async def clear(self, ctx: commands.Context | LionContext, profile):
userid = profile.profileid
if task := self.tasks.pop(userid, None):
await task.delete()
await self.channel.send_del(userid)

View File

@@ -47,16 +47,32 @@ class TimerChannel(Channel):
super().__init__(**kwargs)
self.cog = cog
self.channelid = 1261999440160624734
self.goal = 12
async def on_connection(self, websocket, event):
await super().on_connection(websocket, event)
timer = self.cog.get_channel_timer(1261999440160624734)
if timer is not None:
await self.send_set(
timer.data.last_started,
timer.data.focus_length,
timer.data.break_length,
websocket=websocket,
)
await self.send_set(
**await self.get_args_for(self.channelid),
goal=self.goal,
websocket=websocket,
)
async def send_updates(self):
await self.send_set(
**await self.get_args_for(self.channelid),
goal=self.goal,
)
async def get_args_for(self, channelid):
timer = self.cog.get_channel_timer(channelid)
if timer is None:
raise ValueError(f"Timer {channelid} doesn't exist.")
return {
'start_at': timer.data.last_started,
'focus_length': timer.data.focus_length,
'break_length': timer.data.break_length,
}
async def send_set(self, start_at, focus_length, break_length, goal=12, websocket=None):
await self.send_event({
@@ -304,8 +320,6 @@ class TimerCog(LionCog):
return
if member.bot:
return
if 1148167212901859328 not in [role.id for role in member.roles]:
return
# If a member is leaving or joining a running timer, trigger a status update
if before.channel != after.channel:
@@ -315,6 +329,7 @@ class TimerCog(LionCog):
tasks = []
if leaving is not None:
tasks.append(asyncio.create_task(leaving.update_status_card()))
leaving.last_seen.pop(member.id, None)
if joining is not None:
joining.last_seen[member.id] = utc_now()
if not joining.running and joining.auto_restart:
@@ -1059,8 +1074,18 @@ class TimerCog(LionCog):
@low_management_ward
async def streamtimer_update_cmd(self, ctx: LionContext,
new_start: Optional[str] = None,
new_goal: int = 12):
timer = self.get_channel_timer(1261999440160624734)
new_goal: Optional[int] = None,
new_channel: Optional[discord.VoiceChannel] = None,
):
if new_channel is not None:
channelid = self.channel.channelid = new_channel.id
else:
channelid = self.channel.channelid
if new_goal is not None:
self.channel.goal = new_goal
timer = self.get_channel_timer(channelid)
if timer is None:
return
if new_start:
@@ -1068,10 +1093,5 @@ class TimerCog(LionCog):
start_at = await self.bot.get_cog('Reminders').parse_time_static(new_start, timezone)
await timer.data.update(last_started=start_at)
await self.channel.send_set(
timer.data.last_started,
timer.data.focus_length,
timer.data.break_length,
goal=new_goal,
)
await self.channel.send_updates()
await ctx.reply("Stream Timer Updated")

View File

@@ -8,10 +8,12 @@ from gui.cards import FocusTimerCard, BreakTimerCard
if TYPE_CHECKING:
from .timer import Timer, Stage
from tracking.voice.cog import VoiceTrackerCog
from modules.nowdoing.cog import NowDoingCog
async def get_timer_card(bot: LionBot, timer: 'Timer', stage: 'Stage'):
voicecog: 'VoiceTrackerCog' = bot.get_cog('VoiceTrackerCog')
nowcog: 'NowDoingCog' = bot.get_cog('NowDoingCog')
name = timer.base_name
if stage is not None:
@@ -23,16 +25,22 @@ async def get_timer_card(bot: LionBot, timer: 'Timer', stage: 'Stage'):
card_users = []
guildid = timer.data.guildid
for member in timer.members:
if voicecog is not None:
session = voicecog.get_session(guildid, member.id)
tag = session.tag
if session.start_time:
session_duration = (utc_now() - session.start_time).total_seconds()
else:
session_duration = 0
profile = await bot.get_cog('ProfileCog').fetch_profile_discord(member)
task = nowcog.tasks.get(profile.profileid, None)
tag = ''
session_duration = 0
if task:
tag = task.task
session_duration = ((task.done_at or utc_now()) - task.started_at).total_seconds()
else:
session_duration = 0
tag = None
session = voicecog.get_session(guildid, member.id)
if session:
tag = session.tag
if session.start_time:
session_duration = (utc_now() - session.start_time).total_seconds()
else:
session_duration = 0
card_user = (
(member.id, (member.avatar or member.default_avatar).key),

View File

@@ -0,0 +1,8 @@
import logging
logger = logging.getLogger(__name__)
from .cog import ProfileCog
async def setup(bot):
await bot.add_cog(ProfileCog(bot))

415
src/modules/profiles/cog.py Normal file
View File

@@ -0,0 +1,415 @@
import asyncio
from enum import Enum
from typing import Optional, overload
from datetime import timedelta
import discord
from discord import app_commands as appcmds
from discord.ext import commands as cmds
from twitchAPI.type import AuthScope
import twitchio
from twitchio.ext import commands
from twitchio import User
from twitchAPI.object.api import TwitchUser
from data.queries import ORDER
from meta import LionCog, LionBot, CrocBot, LionContext
from meta.logger import log_wrap
from utils.lib import utc_now
from . import logger
from .data import ProfileData
from .profile import UserProfile
from .community import Community
class ProfileCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
assert bot.crocbot is not None
self.crocbot: CrocBot = bot.crocbot
self.data = bot.db.load_registry(ProfileData())
self._profile_migrators = {}
self._comm_migrators = {}
async def cog_load(self):
await self.data.init()
async def cog_check(self, ctx):
return True
async def bot_check_once(self, ctx: LionContext):
"""
Inject the contextual UserProfile and Community into the LionContext.
Creates the profile and community if they do not exist.
"""
if ctx.guild:
ctx.community = await self.fetch_community_discord(ctx.guild)
ctx.profile = await self.fetch_profile_discord(ctx.author)
return True
# Profile API
def add_profile_migrator(self, migrator, name=None):
name = name or migrator.__name__
self._profile_migrators[name or migrator.__name__] = migrator
logger.info(
f"Added user profile migrator {name}: {migrator}"
)
return migrator
def del_profile_migrator(self, name: str):
migrator = self._profile_migrators.pop(name, None)
logger.info(
f"Removed user profile migrator {name}: {migrator}"
)
@log_wrap(action="profile migration")
async def migrate_profile(self, source_profile, target_profile) -> list[str]:
logger.info(
f"Beginning user profile migration from {source_profile!r} to {target_profile!r}"
)
results = []
# Wrap this in a transaction so if something goes wrong with migration,
# we roll back safely (although this may mess up caches)
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
for name, migrator in self._profile_migrators.items():
try:
result = await migrator(source_profile, target_profile)
if result:
results.append(result)
except Exception:
logger.exception(
f"Unexpected exception running user profile migrator {name} "
f"migrating {source_profile!r} to {target_profile!r}."
)
raise
# Move all Discord and Twitch profile references over to the new profile
discord_rows = await self.data.DiscordProfileRow.table.update_where(
profileid=source_profile.profileid
).set(profileid=target_profile.profileid)
results.append(f"Migrated {len(discord_rows)} attached discord account(s).")
twitch_rows = await self.data.TwitchProfileRow.table.update_where(
profileid=source_profile.profileid
).set(profileid=target_profile.profileid)
results.append(f"Migrated {len(twitch_rows)} attached twitch account(s).")
# And then mark the old profile as migrated
await source_profile.profile_row.update(migrated=target_profile.profileid)
results.append("Marking old profile as migrated.. finished!")
return results
async def fetch_profile_by_id(self, profile_id: int) -> UserProfile:
"""
Fetch a UserProfile by the given id.
"""
return await UserProfile.fetch(self.bot, profile_id=profile_id)
async def fetch_profile_discord(self, user: discord.Member | discord.User) -> UserProfile:
"""
Fetch or create a UserProfile from the provided discord account.
"""
profile = await UserProfile.fetch_from_discordid(self.bot, user.id)
if profile is None:
profile = await UserProfile.create_from_discord(self.bot, user)
return profile
async def fetch_profile_twitch(self, user: twitchio.User) -> UserProfile:
"""
Fetch or create a UserProfile from the provided twitch account.
"""
profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
if profile is None:
profile = await UserProfile.create_from_twitch(self.bot, user)
return profile
# Community API
def add_community_migrator(self, migrator, name=None):
name = name or migrator.__name__
self._comm_migrators[name or migrator.__name__] = migrator
logger.info(
f"Added community migrator {name}: {migrator}"
)
return migrator
def del_community_migrator(self, name: str):
migrator = self._comm_migrators.pop(name, None)
logger.info(
f"Removed community migrator {name}: {migrator}"
)
@log_wrap(action="community migration")
async def migrate_community(self, source_comm, target_comm) -> list[str]:
logger.info(
f"Beginning community migration from {source_comm!r} to {target_comm!r}"
)
results = []
# Wrap this in a transaction so if something goes wrong with migration,
# we roll back safely (although this may mess up caches)
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
for name, migrator in self._comm_migrators.items():
try:
result = await migrator(source_comm, target_comm)
if result:
results.append(result)
except Exception:
logger.exception(
f"Unexpected exception running community migrator {name} "
f"migrating {source_comm!r} to {target_comm!r}."
)
raise
# Move all Discord and Twitch community preferences over to the new profile
discord_rows = await self.data.DiscordCommunityRow.table.update_where(
profileid=source_comm.communityid
).set(communityid=target_comm.communityid)
results.append(f"Migrated {len(discord_rows)} attached discord guilds.")
twitch_rows = await self.data.TwitchCommunityRow.table.update_where(
communityid=source_comm.communityid
).set(communityid=target_comm.communityid)
results.append(f"Migrated {len(twitch_rows)} attached twitch channel(s).")
# And then mark the old community as migrated
await source_comm.update(migrated=target_comm.communityid)
results.append("Marking old community as migrated.. finished!")
return results
async def fetch_community_by_id(self, community_id: int) -> Community:
"""
Fetch a Community by the given id.
"""
return await Community.fetch(self.bot, community_id=community_id)
async def fetch_community_discord(self, guild: discord.Guild) -> Community:
"""
Fetch or create a Community from the provided discord guild.
"""
comm = await Community.fetch_from_discordid(self.bot, guild.id)
if comm is None:
comm = await Community.create_from_discord(self.bot, guild)
return comm
async def fetch_community_twitch(self, user: twitchio.User) -> Community:
"""
Fetch or create a Community from the provided twitch account.
"""
community = await Community.fetch_from_twitchid(self.bot, user.id)
if community is None:
community = await Community.create_from_twitch(self.bot, user)
return community
# ----- Profile Commands -----
@cmds.hybrid_group(
name='profiles',
description="Base comand group for user profiles."
)
async def profiles_grp(self, ctx: LionContext):
...
@profiles_grp.group(
name='link',
description="Base command group for linking profiles"
)
async def profiles_link_grp(self, ctx: LionContext):
...
@profiles_link_grp.command(
name='twitch',
description="Link a twitch account to your current profile."
)
async def profiles_link_twitch_cmd(self, ctx: LionContext):
if not ctx.interaction:
return
await ctx.interaction.response.defer(ephemeral=True)
# Ask the user to go through auth to get their userid
auth_cog = self.bot.get_cog('TwitchAuthCog')
flow = await auth_cog.start_auth()
message = await ctx.reply(
f"Please [click here]({flow.auth.return_auth_url()}) to link your profile "
"to Twitch."
)
authrow = await flow.run()
await message.edit(
content="Authentication Complete! Beginning profile merge..."
)
results = await self.crocbot.fetch_users(ids=[authrow.userid])
if not results:
logger.error(
f"User {authrow} obtained from Twitch authentication does not exist."
)
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
return
user = results[0]
# Retrieve author's profile if it exists
author_profile = await UserProfile.fetch_from_discordid(self.bot, ctx.author.id)
# Check if the twitch-side user has a profile
source_profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
if author_profile and source_profile is None:
# All we need to do is attach the twitch row
await author_profile.attach_twitch(user)
await message.edit(
content=f"Successfully added Twitch account **{user.name}**! There was no profile data to merge."
)
elif source_profile and author_profile is None:
# Attach the discord row to the profile
await source_profile.attach_discord(ctx.author)
await message.edit(
content=f"Successfully connected to Twitch profile **{user.name}**! There was no profile data to merge."
)
elif source_profile is None and author_profile is None:
profile = await UserProfile.create_from_discord(self.bot, ctx.author)
await profile.attach_twitch(user)
await message.edit(
content=f"Opened a new user profile for you and linked Twitch account **{user.name}**."
)
elif author_profile.profileid == source_profile.profileid:
await message.edit(
content=f"The Twitch account **{user.name}** is already linked to your profile!"
)
else:
# Migrate the existing profile data to the new profiles
try:
results = await self.migrate_profile(source_profile, author_profile)
except Exception:
await ctx.error_reply(
"An issue was encountered while merging your account profiles!\n"
"Migration rolled back, no data has been lost.\n"
"The developer has been notified. Please try again later!"
)
raise
content = '\n'.join((
"## Connecting Twitch account and merging profiles...",
*results,
"**Successfully linked account and merged profile data!**"
))
await message.edit(content=content)
# ----- Community Commands -----
@cmds.hybrid_group(
name='community',
description="Base comand group for community profiles."
)
async def community_grp(self, ctx: LionContext):
...
@community_grp.group(
name='link',
description="Base command group for linking communities"
)
async def community_link_grp(self, ctx: LionContext):
...
@community_link_grp.command(
name='twitch',
description="Link a twitch account to this community."
)
@appcmds.guild_only()
@appcmds.default_permissions(manage_guild=True)
async def comm_link_twitch_cmd(self, ctx: LionContext):
if not ctx.interaction:
return
assert ctx.guild is not None
await ctx.interaction.response.defer(ephemeral=True)
if not ctx.author.guild_permissions.manage_guild:
await ctx.error_reply("You need the `MANAGE_GUILD` permission to link this guild to a community.")
return
# Ask the user to go through auth to get their userid
auth_cog = self.bot.get_cog('TwitchAuthCog')
flow = await auth_cog.start_auth(
scopes=[
AuthScope.CHAT_EDIT,
AuthScope.CHAT_READ,
AuthScope.MODERATION_READ,
AuthScope.CHANNEL_BOT,
]
)
message = await ctx.reply(
f"Please [click here]({flow.auth.return_auth_url()}) to link your Twitch channel to this server."
)
authrow = await flow.run()
await message.edit(
content="Authentication Complete! Beginning community profile merge..."
)
results = await self.crocbot.fetch_users(ids=[authrow.userid])
if not results:
logger.error(
f"User {authrow} obtained from Twitch authentication does not exist."
)
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
return
user = results[0]
# Retrieve author's profile if it exists
guild_comm = await Community.fetch_from_discordid(self.bot, ctx.guild.id)
# Check if the twitch-side user has a profile
twitch_comm = await Community.fetch_from_twitchid(self.bot, user.id)
if guild_comm and twitch_comm is None:
# All we need to do is attach the twitch row
await guild_comm.attach_twitch(user)
await message.edit(
content=f"Successfully linked Twitch channel **{user.name}**! There was no community data to merge."
)
elif twitch_comm and guild_comm is None:
# Attach the discord row to the profile
await twitch_comm.attach_discord(ctx.guild)
await message.edit(
content=f"Successfully connected to Twitch channel **{user.name}**!"
)
elif twitch_comm is None and guild_comm is None:
profile = await Community.create_from_discord(self.bot, ctx.guild)
await profile.attach_twitch(user)
await message.edit(
content=f"Created a new community for this server and linked Twitch account **{user.name}**."
)
elif guild_comm.communityid == twitch_comm.communityid:
await message.edit(
content=f"This server is already linked to the Twitch channel **{user.name}**!"
)
else:
# Migrate the existing profile data to the new profiles
try:
results = await self.migrate_community(twitch_comm, guild_comm)
except Exception:
await ctx.error_reply(
"An issue was encountered while merging your community profiles!\n"
"Migration rolled back, no data has been lost.\n"
"The developer has been notified. Please try again later!"
)
raise
content = '\n'.join((
"## Connecting Twitch account and merging community profiles...",
*results,
"**Successfully linked account and merged community data!**"
))
await message.edit(content=content)

View File

@@ -0,0 +1,123 @@
from typing import Optional, Self
import discord
import twitchio
from meta import LionBot
from utils.lib import utc_now
from . import logger
from .data import ProfileData
class Community:
def __init__(self, bot: LionBot, community_row):
self.bot = bot
self.row: ProfileData.CommunityRow = community_row
@property
def cog(self):
return self.bot.get_cog('ProfileCog')
@property
def data(self) -> ProfileData:
return self.cog.data
@property
def communityid(self):
return self.row.communityid
def __repr__(self):
return f"<Community communityid={self.communityid} row={self.row}>"
async def attach_discord(self, guild: discord.Guild):
"""
Attach a new discord guild to this community.
Assumes the discord guild is not already associated to a community.
"""
discord_row = await self.data.DiscordCommunityRow.create(
communityid=self.communityid,
guildid=guild.id
)
logger.info(
f"Attached discord guild {guild!r} to community {self!r}"
)
return discord_row
async def attach_twitch(self, user: twitchio.User):
"""
Attach a new Twitch user channel to this community.
"""
twitch_row = await self.data.TwitchCommunityRow.create(
communityid=self.communityid,
channelid=str(user.id)
)
logger.info(
f"Attached twitch channel {user!r} to community {self!r}"
)
return twitch_row
async def discord_guilds(self) -> list[ProfileData.DiscordCommunityRow]:
"""
Fetch the Discord guild rows associated to this community.
"""
return await self.data.DiscordCommunityRow.fetch_where(communityid=self.communityid)
async def twitch_channels(self) -> list[ProfileData.TwitchCommunityRow]:
"""
Fetch the Twitch user rows associated to this profile.
"""
return await self.data.TwitchCommunityRow.fetch_where(communityid=self.communityid)
@classmethod
async def fetch(cls, bot: LionBot, community_id: int) -> Self:
community_row = await bot.get_cog('ProfileCog').data.CommunityRow.fetch(community_id)
if community_row is None:
raise ValueError("Provided community_id does not exist.")
return cls(bot, community_row)
@classmethod
async def fetch_from_twitchid(cls, bot: LionBot, channelid: int | str) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.TwitchCommunityRow.fetch_where(channelid=str(channelid))
if rows:
return await cls.fetch(bot, rows[0].communityid)
@classmethod
async def fetch_from_discordid(cls, bot: LionBot, guildid: int) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.DiscordCommunityRow.fetch_where(guildid=guildid)
if rows:
return await cls.fetch(bot, rows[0].communityid)
@classmethod
async def create(cls, bot: LionBot, **kwargs) -> Self:
"""
Create a new empty community with the given initial arguments.
Communities should usually be created using `create_from_discord` or `create_from_twitch`
to correctly setup initial preferences (e.g. name, avatar).
"""
# Create a new community
data = bot.get_cog('ProfileCog').data
row = await data.CommunityRow.create(created_at=utc_now(), **kwargs)
return await cls.fetch(bot, row.communityid)
@classmethod
async def create_from_discord(cls, bot: LionBot, guild: discord.Guild, **kwargs) -> Self:
"""
Create a new community using the given Discord guild as a base.
"""
self = await cls.create(bot, **kwargs)
await self.attach_discord(guild)
return self
@classmethod
async def create_from_twitch(cls, bot: LionBot, user: twitchio.User, **kwargs) -> Self:
"""
Create a new profile using the given Twitch channel user as a base.
"""
self = await cls.create(bot, **kwargs)
await self.attach_twitch(user)
return self

View File

@@ -0,0 +1,158 @@
from data import Registry, RowModel
from data.columns import Integer, String, Timestamp
class ProfileData(Registry):
class UserProfileRow(RowModel):
"""
Schema
------
CREATE TABLE user_profiles(
profileid SERIAL PRIMARY KEY,
nickname TEXT,
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'user_profiles'
_cache_ = {}
profileid = Integer(primary=True)
nickname = String()
migrated = Integer()
created_at = Timestamp()
class DiscordProfileRow(RowModel):
"""
Schema
------
CREATE TABLE profiles_discord(
linkid SERIAL PRIMARY KEY,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
userid BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid);
CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid);
"""
_tablename_ = 'profiles_discord'
_cache_ = {}
linkid = Integer(primary=True)
profileid = Integer()
userid = Integer()
created_at = Integer()
@classmethod
async def fetch_profile(cls, profileid: int):
rows = await cls.fetch_where(profiled=profileid)
return next(rows, None)
class TwitchProfileRow(RowModel):
"""
Schema
------
CREATE TABLE profiles_twitch(
linkid SERIAL PRIMARY KEY,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
userid TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid);
CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid);
"""
_tablename_ = 'profiles_twitch'
_cache_ = {}
linkid = Integer(primary=True)
profileid = Integer()
userid = String()
created_at = Timestamp()
@classmethod
async def fetch_profile(cls, profileid: int):
rows = await cls.fetch_where(profiled=profileid)
return next(rows, None)
class CommunityRow(RowModel):
"""
Schema
------
CREATE TABLE communities(
communityid SERIAL PRIMARY KEY,
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities'
_cache_ = {}
communityid = Integer(primary=True)
migrated = Integer()
created_at = Timestamp()
class DiscordCommunityRow(RowModel):
"""
Schema
------
CREATE TABLE communities_discord(
guildid BIGINT PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities_discord'
_cache_ = {}
guildid = Integer(primary=True)
communityid = Integer()
linked_at = Timestamp()
@classmethod
async def fetch_community(cls, communityid: int):
rows = await cls.fetch_where(communityd=communityid)
return next(rows, None)
class TwitchCommunityRow(RowModel):
"""
Schema
------
CREATE TABLE communities_twitch(
channelid TEXT PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'communities_twitch'
_cache_ = {}
channelid = String(primary=True)
communityid = Integer()
linked_at = Timestamp()
@classmethod
async def fetch_community(cls, communityid: int):
rows = await cls.fetch_where(communityd=communityid)
return next(rows, None)
class CommunityMemberRow(RowModel):
"""
Schema
------
CREATE TABLE community_members(
memberid SERIAL PRIMARY KEY,
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
"""
_tablename_ = 'community_members'
_cache_ = {}
memberid = Integer(primary=True)
communityid = Integer()
profileid = Integer()
created_at = Timestamp()

View File

@@ -0,0 +1,124 @@
from typing import Optional, Self
import discord
import twitchio
from meta import LionBot
from utils.lib import utc_now
from . import logger
from .data import ProfileData
class UserProfile:
def __init__(self, bot: LionBot, profile_row):
self.bot = bot
self.profile_row: ProfileData.UserProfileRow = profile_row
@property
def cog(self):
return self.bot.get_cog('ProfileCog')
@property
def data(self) -> ProfileData:
return self.cog.data
@property
def profileid(self):
return self.profile_row.profileid
def __repr__(self):
return f"<UserProfile profileid={self.profileid} profile={self.profile_row}>"
async def attach_discord(self, user: discord.User | discord.Member):
"""
Attach a new discord user to this profile.
Assumes the discord user does not itself have a profile.
"""
discord_row = await self.data.DiscordProfileRow.create(
profileid=self.profileid,
userid=user.id
)
logger.info(
f"Attached discord user {user!r} to profile {self!r}"
)
return discord_row
async def attach_twitch(self, user: twitchio.User):
"""
Attach a new Twitch user to this profile.
"""
twitch_row = await self.data.TwitchProfileRow.create(
profileid=self.profileid,
userid=str(user.id)
)
logger.info(
f"Attached twitch user {user!r} to profile {self!r}"
)
return twitch_row
async def discord_accounts(self) -> list[ProfileData.DiscordProfileRow]:
"""
Fetch the Discord accounts associated to this profile.
"""
return await self.data.DiscordProfileRow.fetch_where(profileid=self.profileid)
async def twitch_accounts(self) -> list[ProfileData.TwitchProfileRow]:
"""
Fetch the Twitch accounts associated to this profile.
"""
return await self.data.TwitchProfileRow.fetch_where(profileid=self.profileid)
@classmethod
async def fetch(cls, bot: LionBot, profile_id: int) -> Self:
profile_row = await bot.get_cog('ProfileCog').data.UserProfileRow.fetch(profile_id)
if profile_row is None:
raise ValueError("Provided profile_id does not exist.")
return cls(bot, profile_row)
@classmethod
async def fetch_from_twitchid(cls, bot: LionBot, userid: int | str) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.TwitchProfileRow.fetch_where(userid=str(userid))
if rows:
return await cls.fetch(bot, rows[0].profileid)
@classmethod
async def fetch_from_discordid(cls, bot: LionBot, userid: int) -> Optional[Self]:
data = bot.get_cog('ProfileCog').data
rows = await data.DiscordProfileRow.fetch_where(userid=(userid))
if rows:
return await cls.fetch(bot, rows[0].profileid)
@classmethod
async def create(cls, bot: LionBot, **kwargs) -> Self:
"""
Create a new empty profile with the given initial arguments.
Profiles should usually be created using `create_from_discord` or `create_from_twitch`
to correctly setup initial profile preferences (e.g. name, avatar).
"""
# Create a new profile
data = bot.get_cog('ProfileCog').data
profile_row = await data.UserProfileRow.create(created_at=utc_now())
profile = await cls.fetch(bot, profile_row.profileid)
return profile
@classmethod
async def create_from_discord(cls, bot: LionBot, user: discord.Member | discord.User, **kwargs) -> Self:
"""
Create a new profile using the given Discord user as a base.
"""
profile = await cls.create(bot, **kwargs)
await profile.attach_discord(user)
return profile
@classmethod
async def create_from_twitch(cls, bot: LionBot, user: twitchio.User, **kwargs) -> Self:
"""
Create a new profile using the given Twitch user as a base.
"""
profile = await cls.create(bot, **kwargs)
await profile.attach_twitch(user)
return profile

View File

@@ -4,5 +4,5 @@ logger = logging.getLogger(__name__)
from .cog import ShoutoutCog
def prepare(bot):
bot.add_cog(ShoutoutCog(bot))
async def setup(bot):
await bot.add_cog(ShoutoutCog(bot))

View File

@@ -4,50 +4,60 @@ from typing import Optional
import twitchio
from twitchio.ext import commands
from meta import CrocBot
from meta import CrocBot, LionBot, LionCog
from utils.lib import replace_multiple
from . import logger
from .data import ShoutoutData
class ShoutoutCog(commands.Cog):
class ShoutoutCog(LionCog):
# Future extension: channel defaults and config
DEFAULT_SHOUTOUT = """
We think that {name} is a great streamer and you should check them out \
and drop a follow! \
They {areorwere} streaming {game} at {channel}
"""
def __init__(self, bot: CrocBot):
COWO_SHOUTOUT = """
We think that {name} is a great coworker and you should check them out for more productive vibes! \
They {areorwere} streaming {game} at {channel}
"""
ART_SHOUTOUT = """
We think that {name} is an awesome artist and you should check them out for cool art and cosy vibes! \
They {areorwere} streaming {game} at {channel}
"""
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.data.load_registry(ShoutoutData())
self.crocbot: CrocBot = bot.crocbot
self.data = bot.db.load_registry(ShoutoutData())
self.loaded = asyncio.Event()
async def cog_load(self):
await self.data.init()
self._load_twitch_methods(self.crocbot)
self.loaded.set()
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
@commands.Cog.event('event_ready') # type: ignore
async def on_ready(self):
await self.ensure_loaded()
async def cog_unload(self):
self.loaded.clear()
self._unload_twitch_methods(self.crocbot)
async def cog_check(self, ctx):
await self.ensure_loaded()
if not self.loaded.is_set():
await ctx.reply("Tasklists are still loading! Please wait a moment~")
return False
return True
async def format_shoutout(self, text: str, user: twitchio.User):
channels = await self.bot.fetch_channels([user.id])
channels = await self.crocbot.fetch_channels([user.id])
if channels:
channel = channels[0]
game = channel.game_name or 'Unknown'
else:
game = 'Unknown'
streams = await self.bot.fetch_streams([user.id])
streams = await self.crocbot.fetch_streams([user.id])
live = bool(streams)
mapping = {
@@ -59,19 +69,28 @@ class ShoutoutCog(commands.Cog):
return replace_multiple(text, mapping)
@commands.command(aliases=['so'])
async def shoutout(self, ctx: commands.Context, user: twitchio.User):
async def shoutout(self, ctx: commands.Context, target: str, typ: Optional[str]=None):
# Make sure caller is mod/broadcaster
# Lookup custom shoutout for this user
# If it exists use it, otherwise use default shoutout
if (ctx.author.is_mod or ctx.author.is_broadcaster):
data = await self.data.CustomShoutout.fetch(int(user.id))
if data:
shoutout = data.content
user = await self.crocbot.seek_user(target)
if user is None:
await ctx.reply(f"Couldn't resolve '{target}' to a valid user.")
else:
shoutout = self.DEFAULT_SHOUTOUT
formatted = await self.format_shoutout(shoutout, user)
await ctx.reply(formatted)
data = await self.data.CustomShoutout.fetch(int(user.id))
if data:
shoutout = data.content
elif typ == 'cowo':
shoutout = self.COWO_SHOUTOUT
elif typ == 'art':
shoutout = self.ART_SHOUTOUT
else:
shoutout = self.DEFAULT_SHOUTOUT
formatted = await self.format_shoutout(shoutout, user)
await ctx.reply(formatted)
# TODO: How to /shoutout with lib?
# TODO Shoutout queue
@commands.command()
async def editshoutout(self, ctx: commands.Context, user: twitchio.User, *, text: str):

View File

@@ -4,5 +4,5 @@ logger = logging.getLogger(__name__)
from .cog import TagCog
def prepare(bot):
bot.add_cog(TagCog(bot))
async def setup(bot):
await bot.add_cog(TagCog(bot))

View File

@@ -6,16 +6,17 @@ import difflib
import twitchio
from twitchio.ext import commands
from meta import CrocBot
from meta import CrocBot, LionBot, LionCog
from utils.lib import utc_now
from . import logger
from .data import TagData
class TagCog(commands.Cog):
def __init__(self, bot: CrocBot):
class TagCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.data.load_registry(TagData())
self.crocbot = bot.crocbot
self.data = bot.db.load_registry(TagData())
self.loaded = asyncio.Event()
@@ -31,19 +32,24 @@ class TagCog(commands.Cog):
self.tags.clear()
self.tags.update(tags)
logger.info(f"Loaded {len(tags)} into cache.")
async def cog_load(self):
await self.data.init()
await self.load_tags()
self._load_twitch_methods(self.crocbot)
self.loaded.set()
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
async def cog_unload(self):
self.loaded.clear()
self.tags.clear()
self._unload_twitch_methods(self.crocbot)
@commands.Cog.event('event_ready')
async def on_ready(self):
await self.ensure_loaded()
async def cog_check(self, ctx):
if not self.loaded.is_set():
await ctx.reply("Tasklists are still loading! Please wait a moment~")
return False
return True
# API

View File

@@ -7,9 +7,12 @@ from discord.ext import commands as cmds
from discord import app_commands as appcmds
from discord.app_commands.transformers import AppCommandOptionType as cmdopt
from data.queries import JOINTYPE
from meta import LionBot, LionCog, LionContext
from meta.CrocBot import CrocBot
from meta.logger import log_wrap
from meta.errors import UserInputError
from modules.profiles.profile import UserProfile
from utils.lib import utc_now, error_embed
from utils.ui import ChoicedEnum, Transformed, AButton
@@ -126,6 +129,7 @@ class TasklistCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.crocbot: CrocBot = bot.crocbot
self.data = bot.db.load_registry(TasklistData())
self.babel = babel
self.settings = TasklistSettings()
@@ -138,10 +142,84 @@ class TasklistCog(LionCog):
self.bot.core.guild_config.register_model_setting(self.settings.task_reward_limit)
self.bot.add_view(TasklistCaller(self.bot))
self.bot.profiles.add_profile_migrator(self.migrate_profiles, name='tasklist-migrator')
configcog = self.bot.get_cog('ConfigCog')
self.crossload_group(self.configure_group, configcog.config_group)
@LionCog.listener('on_tasks_completed')
self._load_twitch_methods(self.crocbot)
async def cog_unload(self):
self.live_tasklists.clear()
if profiles := self.bot.get_cog('ProfileCog'):
profiles.del_profile_migrator('tasklist-migrator')
self._unload_twitch_methods(self.crocbot)
@log_wrap(action="Tasklist Profile Migration")
async def migrate_profiles(self, source_profile: UserProfile, target_profile: UserProfile):
"""
Re-assign all tasklist tasks from source profile to target profile.
TODO: Probably wants some elegant handling of the cached or running tasklists.
"""
results = ["(Tasklist)"]
sourceid = source_profile.profileid
targetid = target_profile.profileid
updated = await self.data.Task.table.update_where(userid=sourceid).set(userid=targetid)
if updated:
results.append(
f"Migrated {len(updated)} task row(s) from source profile."
)
for channel_lists in self.live_tasklists.get(sourceid, []):
for tasklist in list(channel_lists.values()):
await tasklist.close()
self.bot.dispatch('tasklist_update', profileid=targetid, summon=False)
else:
results.append(
"No tasks found in source profile, nothing to migrate!"
)
return ' '.join(results)
async def user_profile_migration(self):
"""
Manual one-shot migration method from old Discord userids to the new profileids.
"""
# First collect all the distinct userids from the tasklist
# Then create a map of userids to profileids, creating the profiles if required
# Then do updates, we can just inefficiently do updates on each distinct userid
# As long as the userids and profileids never overlap, this is fine. Fine for a one-shot
# Extract all the userids that exist in the table
rows = await self.data.Task.table.select_where().select(
userid="DISTINCT(userid)"
).with_no_adapter()
# Fetch or create discord user profiles for them
profile_map = {}
for row in rows:
userid = row['userid']
if userid > 100000:
# Assume a Discord snowflake
profile = await UserProfile.fetch_from_discordid(self.bot, userid)
if not profile:
try:
user = self.bot.get_user(userid)
if user is None:
user = await self.bot.fetch_user(userid)
except discord.HTTPException:
logger.info(f"Skipping user {userid}")
continue
profile = await UserProfile.create_from_discord(self.bot, user)
profile_map[userid] = profile
# Now iterate through
for userid, profile in profile_map.items():
logger.info(f"Migrating userid {userid} to profile {profile}")
await self.data.Task.table.update_where(userid=userid).set(userid=profile.profileid)
# Temporarily disabling integration with userid driven Economy
# @LionCog.listener('on_tasks_completed')
@log_wrap(action="reward tasks completed")
async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
async with self.bot.db.connection() as conn:
@@ -170,6 +248,9 @@ class TasklistCog(LionCog):
)
async def is_tasklist_channel(self, channel) -> bool:
"""
Check whether a given Discord channel is a tasklist channel
"""
if not channel.guild:
return True
channels = (await self.settings.tasklist_channels.get(channel.guild.id)).value
@@ -186,12 +267,16 @@ class TasklistCog(LionCog):
return (channel in channels) or (channel.id in private_channels) or (channel.category in channels)
async def call_tasklist(self, interaction: discord.Interaction):
"""
Given a Discord channel interaction, summon the interacting user's tasklist.
"""
await interaction.response.defer(thinking=True, ephemeral=True)
channel = interaction.channel
guild = channel.guild
userid = interaction.user.id
profile = await self.bot.profiles.fetch_profile_discord(interaction.user)
profileid = profile.profileid
tasklist = await Tasklist.fetch(self.bot, self.data, userid)
tasklist = await Tasklist.fetch(self.bot, self.data, profileid)
if await self.is_tasklist_channel(channel):
# Check we have permissions to send a regular message here
@@ -213,7 +298,7 @@ class TasklistCog(LionCog):
)
await interaction.edit_original_response(embed=error)
else:
tasklistui = TasklistUI.fetch(tasklist, channel, guild, timeout=None)
tasklistui = TasklistUI.fetch(tasklist, channel, guild, caller=interaction.user, timeout=None)
await tasklistui.summon(force=True)
await interaction.delete_original_response()
else:
@@ -222,14 +307,14 @@ class TasklistCog(LionCog):
await tasklistui.run(interaction)
@LionCog.listener('on_tasklist_update')
async def update_listening_tasklists(self, userid, channel=None, summon=True):
async def update_listening_tasklists(self, profileid, channel=None, summon=True):
"""
Propagate a tasklist update to all persistent tasklist UIs for this user.
If channel is given, also summons the UI if the channel is a tasklist channel.
"""
# Do the given channel first, and summon if requested
if channel and (tui := TasklistUI._live_[userid].get(channel.id, None)) is not None:
if channel and (tui := TasklistUI._live_[profileid].get(channel.id, None)) is not None:
try:
if summon and await self.is_tasklist_channel(channel):
await tui.summon()
@@ -240,7 +325,7 @@ class TasklistCog(LionCog):
await tui.close()
# Now do the rest of the listening channels
listening = TasklistUI._live_[userid]
listening = TasklistUI._live_[profileid]
for cid, ui in list(listening.items()):
if channel and channel.id == cid:
# We already did this channel
@@ -275,7 +360,7 @@ class TasklistCog(LionCog):
async def tasklist_group(self, ctx: LionContext):
raise NotImplementedError
async def _task_acmpl(self, userid: int, partial: str, multi=False) -> list[appcmds.Choice]:
async def _task_acmpl(self, profileid: int, partial: str, multi=False) -> list[appcmds.Choice]:
"""
Generate a list of task Choices matching a given partial string.
@@ -284,7 +369,7 @@ class TasklistCog(LionCog):
t = self.bot.translator.t
# Should usually be cached, so this won't trigger repetitive db access
tasklist = await Tasklist.fetch(self.bot, self.data, userid)
tasklist = await Tasklist.fetch(self.bot, self.data, profileid)
# Special case for an empty tasklist
if not tasklist.tasklist:
@@ -392,13 +477,17 @@ class TasklistCog(LionCog):
"""
Shared autocomplete for single task parameters.
"""
return await self._task_acmpl(interaction.user.id, partial, multi=False)
profile = await self.bot.profiles.fetch_profile_discord(interaction.user)
profileid = profile.profileid
return await self._task_acmpl(profileid, partial, multi=False)
async def tasks_acmpl(self, interaction: discord.Interaction, partial: str) -> list[appcmds.Choice]:
"""
Shared autocomplete for multiple task parameters.
"""
return await self._task_acmpl(interaction.user.id, partial, multi=True)
profile = await self.bot.profiles.fetch_profile_discord(interaction.user)
profileid = profile.profileid
return await self._task_acmpl(profileid, partial, multi=True)
@tasklist_group.command(
name=_p('cmd:tasks_new', "new"),
@@ -422,7 +511,7 @@ class TasklistCog(LionCog):
if not ctx.interaction:
return
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
# Fetch parent task if required
@@ -453,9 +542,9 @@ class TasklistCog(LionCog):
)
await ctx.interaction.edit_original_response(
embed=embed,
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot)
)
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
tasklist_new_cmd.autocomplete('parent')(task_acmpl)
@@ -523,7 +612,7 @@ class TasklistCog(LionCog):
raise UserInputError(error)
# Contents successfully parsed, update the tasklist.
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
taskinfo = tasklist.parse_tasklist(lines)
@@ -572,9 +661,9 @@ class TasklistCog(LionCog):
)
await ctx.interaction.edit_original_response(
embed=embed,
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot)
)
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
@tasklist_group.command(
name=_p('cmd:tasks_edit', "edit"),
@@ -600,7 +689,7 @@ class TasklistCog(LionCog):
t = self.bot.translator.t
if not ctx.interaction:
return
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
# Fetch task to edit
tid = tasklist.parse_label(taskstr) if taskstr else None
@@ -651,12 +740,12 @@ class TasklistCog(LionCog):
await interaction.response.send_message(
embed=embed,
view=(
discord.utils.MISSING if ctx.channel.id in TasklistUI._live_[ctx.author.id]
discord.utils.MISSING if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid]
else TasklistCaller(self.bot)
),
ephemeral=True
)
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
if new_content or new_parent:
# Manual edit route
@@ -688,17 +777,17 @@ class TasklistCog(LionCog):
async def tasklist_clear_cmd(self, ctx: LionContext):
t = ctx.bot.translator.t
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
await tasklist.update_tasklist(deleted_at=utc_now())
await ctx.reply(
t(_p(
'cmd:tasks_clear|resp:success',
"Your tasklist has been cleared."
)),
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot),
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot),
ephemeral=True
)
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
@tasklist_group.command(
name=_p('cmd:tasks_remove', "remove"),
@@ -748,7 +837,7 @@ class TasklistCog(LionCog):
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
conditions = []
if taskidstr:
@@ -784,7 +873,7 @@ class TasklistCog(LionCog):
elif completed is False:
conditions.append(self.data.Task.completed_at == NULL)
tasks = await self.data.Task.fetch_where(*conditions, userid=ctx.author.id)
tasks = await self.data.Task.fetch_where(*conditions, userid=ctx.profile.profileid)
if not tasks:
await ctx.interaction.edit_original_response(
embed=error_embed(t(_p(
@@ -813,9 +902,9 @@ class TasklistCog(LionCog):
)
await ctx.interaction.edit_original_response(
embed=embed,
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot)
)
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
tasklist_remove_cmd.autocomplete('taskidstr')(tasks_acmpl)
@@ -844,7 +933,7 @@ class TasklistCog(LionCog):
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
try:
taskids = tasklist.parse_labels(taskidstr)
@@ -889,9 +978,9 @@ class TasklistCog(LionCog):
)
await ctx.interaction.edit_original_response(
embed=embed,
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot)
)
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
tasklist_tick_cmd.autocomplete('taskidstr')(tasks_acmpl)
@@ -920,7 +1009,7 @@ class TasklistCog(LionCog):
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
try:
taskids = tasklist.parse_labels(taskidstr)
@@ -962,9 +1051,9 @@ class TasklistCog(LionCog):
)
await ctx.interaction.edit_original_response(
embed=embed,
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot)
)
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
tasklist_untick_cmd.autocomplete('taskidstr')(tasks_acmpl)

View File

@@ -5,6 +5,7 @@ from data.columns import Integer, String, Timestamp, Bool
class TasklistData(Registry):
class Task(RowModel):
"""
Row model describing a single task in a tasklist.
@@ -14,21 +15,17 @@ class TasklistData(Registry):
CREATE TABLE tasklist(
taskid SERIAL PRIMARY KEY,
userid BIGINT NOT NULL REFERENCES user_config ON DELETE CASCADE,
profileid INTEGER NOT NULL REFERENCES user_profiles ON DELETE CASCADE ON UPDATE CASCADE,
parentid INTEGER REFERENCES tasklist (taskid) ON DELETE SET NULL,
content TEXT NOT NULL,
rewarded BOOL DEFAULT FALSE,
deleted_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ,
duration INTEGER,
last_updated_at TIMESTAMPTZ
);
CREATE INDEX tasklist_users ON tasklist (userid);
CREATE TABLE tasklist_channels(
guildid BIGINT NOT NULL REFERENCES guild_config (guildid) ON DELETE CASCADE,
channelid BIGINT NOT NULL
);
CREATE INDEX tasklist_channels_guilds ON tasklist_channels (guildid);
"""
_tablename_ = "tasklist"
@@ -41,5 +38,26 @@ class TasklistData(Registry):
created_at = Timestamp()
deleted_at = Timestamp()
last_updated_at = Timestamp()
duration = Integer()
"""
Schema
------
CREATE TABLE tasklist_channels(
guildid BIGINT NOT NULL REFERENCES guild_config (guildid) ON DELETE CASCADE,
channelid BIGINT NOT NULL
);
CREATE INDEX tasklist_channels_guilds ON tasklist_channels (guildid);
"""
channels = Table('tasklist_channels')
"""
Schema
------
CREATE TABLE current_tasks(
taskid PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
last_started_at TIMESTAMPTZ NOT NULL
);
"""
current_tasks = Table('current_tasks')

View File

@@ -0,0 +1,23 @@
ALTER TABLE tasklist
DROP CONSTRAINT fk_tasklist_users;
ALTER TABLE tasklist
ADD CONSTRAINT fk_tasklist_users
FOREIGN KEY (userid)
REFERENCES user_profiles (profileid)
ON DELETE CASCADE
ON UPDATE CASCADE
NOT VALID;
ALTER TABLE tasklist
ADD COLUMN duration INTEGER;
CREATE TABLE tasklist_current(
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
started_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE tasklist_planner(
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
sortkey INTEGER
);

View File

@@ -232,13 +232,18 @@ class TasklistUI(BasePager):
def __init__(self,
tasklist: Tasklist,
channel: discord.abc.Messageable, guild: Optional[discord.Guild] = None, **kwargs):
channel: discord.abc.Messageable,
guild: Optional[discord.Guild] = None,
caller: Optional[discord.User | discord.Member] = None,
**kwargs):
kwargs.setdefault('timeout', 600)
super().__init__(**kwargs)
self.bot = tasklist.bot
self.tasklist = tasklist
self.labelled = tasklist.labelled
self.caller = caller
# NOTE: This is now a profiled
self.userid = tasklist.userid
self.channel = channel
self.guild = guild
@@ -449,9 +454,10 @@ class TasklistUI(BasePager):
cascade=True,
completed_at=utc_now()
)
if self.guild:
if (member := self.guild.get_member(self.userid)):
self.bot.dispatch('tasks_completed', member, *(t.taskid for t in to_complete))
# TODO: Removed economy integration
# if self.guild:
# if (member := self.guild.get_member(self.userid)):
# self.bot.dispatch('tasks_completed', member, *(t.taskid for t in to_complete))
if to_uncomplete:
await self.tasklist.update_tasks(
*(t.taskid for t in to_uncomplete),
@@ -475,7 +481,7 @@ class TasklistUI(BasePager):
if shared_root:
self._subtree_root = labelled[shared_root].taskid
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
async def _delete_menu(self, interaction: discord.Interaction, selected: Select, subtree: bool):
await interaction.response.defer()
@@ -486,7 +492,7 @@ class TasklistUI(BasePager):
cascade=True,
deleted_at=utc_now()
)
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
async def _edit_menu(self, interaction: discord.Interaction, selected: Select, subtree: bool):
if not selected.values:
@@ -513,7 +519,7 @@ class TasklistUI(BasePager):
self._last_parentid = new_parentid
if not subtree:
self._subtree_root = new_parentid
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
await interaction.response.send_modal(editor)
@@ -606,7 +612,7 @@ class TasklistUI(BasePager):
self._subtree_root = pid
await interaction.response.defer()
await self.tasklist.create_task(new_task, parentid=pid)
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
await press.response.send_modal(editor)
@@ -667,7 +673,7 @@ class TasklistUI(BasePager):
@editor.add_callback
async def editor_callback(interaction: discord.Interaction):
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
if sum(len(line) for line in editor.lines.values()) + len(editor.lines) >= 4000:
await press.response.send_message(
@@ -698,7 +704,7 @@ class TasklistUI(BasePager):
await self.tasklist.update_tasklist(
deleted_at=utc_now(),
)
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
async def clear_button_refresh(self):
self.clear_button.label = self.bot.translator.t(_p(
@@ -771,11 +777,12 @@ class TasklistUI(BasePager):
# ----- UI Flow -----
def access_check(self, userid):
return userid == self.userid
return userid in (self.userid, self.caller.id if self.caller else None)
async def interaction_check(self, interaction: discord.Interaction):
t = self.bot.translator.t
if not self.access_check(interaction.user.id):
interaction_profile = await self.bot.profiles.fetch_profile_discord(interaction.user)
if not self.access_check(interaction_profile.profileid):
embed = discord.Embed(
description=t(_p(
'ui:tasklist|error:wrong_user',
@@ -812,10 +819,7 @@ class TasklistUI(BasePager):
total = len(tasks)
completed = sum(t.completed_at is not None for t in tasks)
if self.guild:
user = self.guild.get_member(self.userid)
else:
user = self.bot.get_user(self.userid)
user = self.caller
user_name = user.name if user else str(self.userid)
user_colour = user.colour if user else discord.Color.orange()

View File

@@ -0,0 +1,7 @@
import logging
logger = logging.getLogger(__name__)
async def setup(bot):
from .cog import VoiceRoleCog
await bot.add_cog(VoiceRoleCog(bot))

View File

@@ -0,0 +1,166 @@
from collections import defaultdict
from typing import Optional
import asyncio
from cachetools import FIFOCache
from weakref import WeakValueDictionary
import discord
from discord.abc import GuildChannel
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionCog, LionContext
from meta.logger import log_wrap
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
from utils.ui import Confirm
from . import logger
from .data import VoiceRoleData
class VoiceRoleCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(VoiceRoleData())
self._event_locks: WeakValueDictionary[tuple[int, int], asyncio.Lock] = WeakValueDictionary()
async def cog_load(self):
await self.data.init()
@LionCog.listener('on_voice_state_update')
@log_wrap(action='Voice Role Update')
async def voicerole_update(self, member: discord.Member,
before: discord.VoiceState, after: discord.VoiceState):
if member.bot:
return
after_channel = after.channel
before_channel = before.channel
if after_channel == before_channel:
return
task_key = (member.guild.id, member.id)
async with self.event_lock(task_key):
# Get the roles of the channel they left to remove
# Get the roles of the channel they are joining to add
# Use a set difference to remove the roles to be added from the ones to remove
if before_channel is not None:
leaving_roles = await self.get_roles_for(before_channel.id)
else:
leaving_roles = []
if after_channel is not None:
gaining_roles = await self.get_roles_for(after_channel.id)
else:
gaining_roles = []
to_remove = []
for role in leaving_roles:
if role in member.roles and role not in gaining_roles and role.is_assignable():
to_remove.append(role)
to_add = []
for role in gaining_roles:
if role not in member.roles and role.is_assignable():
to_add.append(role)
if to_remove:
await member.remove_roles(*to_remove, reason="Removing voice channel associated roles.")
if to_add:
await member.add_roles(*to_add, reason="Adding voice channel associated roles.")
logger.info(
f"Voice roles removed {len(to_remove)} roles "
f"and added {len(to_add)} roles to <uid: {member.id}>"
)
async def get_roles_for(self, channelid: int) -> list[discord.Role]:
"""
Get the voice roles associated to the given channel, as a list.
Returns an empty list if there are no associated voice roles.
"""
rows = await self.data.VoiceRole.fetch_where(channelid=channelid)
channel = self.bot.get_channel(channelid)
if not channel:
raise ValueError("Provided voice role target channel is not in cache.")
target_roles = []
for row in rows:
role = channel.guild.get_role(row.roleid)
if role is not None:
target_roles.append(role)
return target_roles
def event_lock(self, key) -> asyncio.Lock:
"""
Get an asyncio.Lock for the given key.
Guarantees sequential event handling.
"""
lock = self._event_locks.get(key, None)
if lock is None:
lock = self._event_locks[key] = asyncio.Lock()
logger.debug(f"Getting video event lock {key} (locked: {lock.locked()})")
return lock
# -------- Commands --------
@cmds.hybrid_group(
name='voiceroles',
description="Base command group for voice channel -> role associationes."
)
@appcmds.default_permissions(manage_channels=True)
async def voicerole_group(self, ctx: LionContext):
...
@voicerole_group.command(
name="link",
description="Link a given voice channel with a given role."
)
@appcmds.describe(
channel="The voice channel to link.",
role="The associated role to give to members joining the voice channel."
)
async def voicerole_link(self, ctx: LionContext,
channel: discord.VoiceChannel,
role: discord.Role):
if not ctx.interaction:
return
if not channel.permissions_for(ctx.author).manage_channels:
await ctx.error_reply(f"You don't have the manage channels permission in {channel.mention}")
return
if not ctx.author.guild_permissions.manage_roles or not (role < ctx.author.top_role):
await ctx.error_reply(f"You don't have the permission to manage this role!")
return
await self.data.VoiceRole.table.insert(channelid=channel.id, roleid=role.id)
await ctx.reply("Voice role associated!")
@voicerole_group.command(
name="unlink",
description="Unlink a given voice channel from a given role."
)
@appcmds.describe(
channel="The voice channel to unlink.",
role="The role to remove from this voice channel."
)
async def voicerole_unlink(self, ctx: LionContext,
channel: discord.VoiceChannel,
role: discord.Role):
if not ctx.interaction:
return
if not channel.permissions_for(ctx.author).manage_channels:
await ctx.error_reply(f"You don't have the manage channels permission in {channel.mention}")
return
if not ctx.author.guild_permissions.manage_roles or not (role < ctx.author.top_role):
await ctx.error_reply(f"You don't have the permission to manage this role!")
return
await self.data.VoiceRole.table.delete_where(channelid=channel.id, roleid=role.id)
await ctx.reply("Voice role disassociated!")
# TODO: Display and visual editing of roles.

View File

@@ -0,0 +1,27 @@
from data import Registry, RowModel
from data.columns import Integer, Timestamp
class VoiceRoleData(Registry):
class VoiceRole(RowModel):
"""
Schema
------
CREATE TABLE voice_roles(
voice_role_id SERIAL PRIMARY KEY,
channelid BIGINT NOT NULL,
roleid BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX voice_role_channels on voice_roles (channelid);
"""
# TODO: Worth associating a guildid to this as well? Denormalises though
# Makes more theoretical sense to associated configurable channels to the guilds in a join table.
_tablename_ = 'voice_roles'
_cache_ = {}
voice_role_id = Integer(primary=True)
channelid = Integer()
roleid = Integer()
created_at = Timestamp()

View File

@@ -5,7 +5,7 @@ import datetime as dt
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from discord import AllowedMentions, app_commands as appcmds
from data import Condition
from meta import LionBot, LionCog, LionContext
@@ -654,7 +654,7 @@ class VoiceTrackerCog(LionCog):
# ----- Commands -----
@cmds.hybrid_command(
name=_p('cmd:now', "now"),
name="tag",
description=_p(
'cmd:now|desc',
"Describe what you are working on, or see what your friends are working on!"
@@ -668,7 +668,7 @@ class VoiceTrackerCog(LionCog):
@appcmds.describe(
tag=_p(
'cmd:now|param:tag|desc',
"Describe what you are working on in 10 characters or less!"
"Describe what you are working!"
),
user=_p(
'cmd:now|param:user|desc',
@@ -681,17 +681,15 @@ class VoiceTrackerCog(LionCog):
)
@appcmds.guild_only
async def now_cmd(self, ctx: LionContext,
tag: Optional[appcmds.Range[str, 0, 10]] = None,
tag: Optional[str] = None,
*,
user: Optional[discord.Member] = None,
clear: Optional[bool] = None
):
if not ctx.guild:
return
if not ctx.interaction:
return
t = self.bot.translator.t
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
is_moderator = await moderator_ctxward(ctx)
target = user if user is not None else ctx.author
session = self.get_session(ctx.guild.id, target.id, create=False)
@@ -715,7 +713,7 @@ class VoiceTrackerCog(LionCog):
"{mention} has no running session!"
)).format(mention=target.mention)
)
await ctx.interaction.edit_original_response(embed=error)
await ctx.reply(embed=error)
return
if clear:
@@ -723,87 +721,27 @@ class VoiceTrackerCog(LionCog):
if target == ctx.author:
# Clear the author's tag
await session.set_tag(None)
ack = discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p(
'cmd:now|target:self|mode:clear|success|title',
"Session Tag Cleared"
)),
description=t(_p(
'cmd:now|target:self|mode:clear|success|desc',
"Successfully unset your session tag."
))
)
ack = "Cleared your current task!"
elif not is_moderator:
# Trying to clear someone else's tag without being a moderator
ack = discord.Embed(
colour=discord.Colour.brand_red(),
title=t(_p(
'cmd:now|target:other|mode:clear|error:perms|title',
"You can't do that!"
)),
description=t(_p(
'cmd:now|target:other|mode:clear|error:perms|desc',
"You need to be a moderator to set or clear someone else's session tag."
))
)
ack = "You need to be a moderator to set or clear someone else's task!"
else:
# Clearing someone else's tag as a moderator
await session.set_tag(None)
ack = discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p(
'cmd:now|target:other|mode:clear|success|title',
"Session Tag Cleared!"
)),
description=t(_p(
'cmd:now|target:other|mode:clear|success|desc',
"Cleared {target}'s session tag."
)).format(target=target.mention)
)
ack = f"Cleared {target}'s current task!"
elif tag:
# Tag setting mode
if target == ctx.author:
# Set the author's tag
await session.set_tag(tag)
ack = discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p(
'cmd:now|target:self|mode:set|success|title',
"Session Tag Set!"
)),
description=t(_p(
'cmd:now|target:self|mode:set|success|desc',
"You are now working on `{new_tag}`. Good luck!"
)).format(new_tag=tag)
)
ack = f"Set your current task to `{tag}`, good luck! <:goodluck:1266447460146876497>"
elif not is_moderator:
# Trying the set someone else's tag without being a moderator
ack = discord.Embed(
colour=discord.Colour.brand_red(),
title=t(_p(
'cmd:now|target:other|mode:set|error:perms|title',
"You can't do that!"
)),
description=t(_p(
'cmd:now|target:other|mode:set|error:perms|desc',
"You need to be a moderator to set or clear someone else's session tag!"
))
)
ack = "You need to be a moderator to set or clear someone else's task!"
else:
# Setting someone else's tag as a moderator
await session.set_tag(tag)
ack = discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p(
'cmd:now|target:other|mode:set|success|title',
"Session Tag Set!"
)),
description=t(_p(
'cmd:now|target:other|mode:set|success|desc',
"Set {target}'s session tag to `{new_tag}`."
)).format(target=target.mention, new_tag=tag)
)
ack = f"Set {target}'s current task to `{tag}`"
else:
# Display tag and voice time
if target == ctx.author:
@@ -815,14 +753,14 @@ class VoiceTrackerCog(LionCog):
else:
desc = t(_p(
'cmd:now|target:self|mode:show_without_tag|desc',
"You have been working in {channel} since {time}!\n\n"
"You have been working in {channel} since {time}! "
"Use `/now <tag>` to set what you are working on."
))
else:
if session.tag:
desc = t(_p(
'cmd:now|target:other|mode:show_with_tag|desc',
"{target} is current working in {channel}!\n"
"{target} is current working in {channel}! "
"They have been working on **{tag}** since {time}."
))
else:
@@ -830,18 +768,13 @@ class VoiceTrackerCog(LionCog):
'cmd:now|target:other|mode:show_without_tag|desc',
"{target} has been working in {channel} since {time}!"
))
desc = desc.format(
ack = desc.format(
tag=session.tag,
channel=f"<#{session.state.channelid}>",
time=discord.utils.format_dt(session.start_time, 't'),
time=discord.utils.format_dt(session.start_time, 'R'),
target=target.mention,
)
ack = discord.Embed(
colour=discord.Colour.orange(),
description=desc,
timestamp=utc_now()
)
await ctx.interaction.edit_original_response(embed=ack)
await ctx.reply(ack, allowed_mentions=AllowedMentions.none())
# ----- Configuration Commands -----
@LionCog.placeholder_group

9
src/twitch/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
import logging
logger = logging.getLogger(__name__)
from .cog import TwitchAuthCog
async def setup(bot):
await bot.add_cog(TwitchAuthCog(bot))

50
src/twitch/authclient.py Normal file
View File

@@ -0,0 +1,50 @@
"""
Testing client for the twitch AuthServer.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
import asyncio
import aiohttp
from twitchAPI.twitch import Twitch
from twitchAPI.oauth import UserAuthenticator
from twitchAPI.type import AuthScope
from meta.config import conf
URI = "http://localhost:3000/twiauth/confirm"
TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ]
async def main():
# Load in client id and secret
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI)
url = auth.return_auth_url()
# Post url to user
print(url)
# Send listen request to server
# Wait for listen request
async with aiohttp.ClientSession() as session:
async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws:
await ws.send_json({'state': auth.state})
result = await ws.receive_json()
# Hopefully get back code, print the response
print(f"Recieved: {result}")
# Authorise with code and client details
tokens = await auth.authenticate(user_token=result['code'])
if tokens:
token, refresh = tokens
await twitch.set_user_authentication(token, TARGET_SCOPE, refresh)
print(f"Authorised!")
if __name__ == '__main__':
asyncio.run(main())

86
src/twitch/authserver.py Normal file
View File

@@ -0,0 +1,86 @@
import logging
import uuid
import asyncio
from contextvars import ContextVar
import aiohttp
from aiohttp import web
logger = logging.getLogger(__name__)
reqid: ContextVar[str] = ContextVar('reqid', default='ROOT')
class AuthServer:
def __init__(self):
self.listeners = {}
async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse:
args = request.query
if 'state' not in args:
raise web.HTTPBadRequest(text="No state provided.")
if args['state'] not in self.listeners:
raise web.HTTPBadRequest(text="Invalid state.")
self.listeners[args['state']].set_result(dict(args))
return web.Response(text="Authorisation complete! You may now close this page and return to the application.")
async def handle_listen_request(self, request: web.Request) -> web.StreamResponse:
_reqid = str(uuid.uuid1())
reqid.set(_reqid)
logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}")
ws = web.WebSocketResponse()
await ws.prepare(request)
# Get the listen request data
try:
listen_req = await ws.receive_json(timeout=60)
logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}")
if 'state' not in listen_req:
logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.")
raise web.HTTPBadRequest(text="Listen request must include state string.")
elif listen_req['state'] in self.listeners:
logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.")
raise web.HTTPBadRequest(text="Invalid state string.")
except ValueError:
logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.")
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
except TypeError:
logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.")
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
except asyncio.TimeoutError:
logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.")
raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.")
except Exception:
logger.exception(f"[reqid: {_reqid}] Unknown exception.")
raise web.HTTPInternalServerError()
try:
fut = self.listeners[listen_req['state']] = asyncio.Future()
result = await asyncio.wait_for(fut, timeout=120)
except asyncio.TimeoutError:
logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.")
raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.")
finally:
self.listeners.pop(listen_req['state'], None)
logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.")
await ws.send_json(result)
await ws.close()
logger.debug(f"[reqid: {_reqid}] Request completed handling.")
return ws
def main(argv):
app = web.Application()
server = AuthServer()
app.router.add_get("/twiauth/confirm", server.handle_twitch_callback)
app.router.add_get("/twiauth/listen", server.handle_listen_request)
logger.info("App setup and configured. Starting now.")
web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080)
if __name__ == '__main__':
import sys
main(sys.argv)

84
src/twitch/cog.py Normal file
View File

@@ -0,0 +1,84 @@
import asyncio
from enum import Enum
from typing import Optional
from datetime import timedelta
import discord
from discord.ext import commands as cmds
from twitchAPI.oauth import UserAuthenticator
from twitchAPI.twitch import AuthType, Twitch
from twitchAPI.type import AuthScope
import twitchio
from twitchio.ext import commands
from data.queries import ORDER
from meta import LionCog, LionBot, CrocBot
from meta.LionContext import LionContext
from twitch.userflow import UserAuthFlow
from utils.lib import utc_now
from . import logger
from .data import TwitchAuthData
class TwitchAuthCog(LionCog):
DEFAULT_SCOPES = []
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(TwitchAuthData())
async def cog_load(self):
await self.data.init()
# ----- Auth API -----
async def fetch_client_for(self, userid: int):
...
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
"""
Checks whether the given userid is authorised.
If 'scopes' is given, will also check the user has all of the given scopes.
"""
authrow = await self.data.UserAuthRow.fetch(userid)
if authrow:
if scopes:
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
has_auth = set(map(str, scopes)).issubset(has_scopes)
else:
has_auth = True
else:
has_auth = False
return has_auth
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
"""
Start the user authentication flow for the given userid.
Will request the given scopes along with the default ones and any existing scopes.
"""
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
existing = map(AuthScope, existing_strs)
to_request = set(existing).union(scopes)
return await self.start_auth(to_request)
async def start_auth(self, scopes = []):
# TODO: Work out a way to just clone the current twitch object
# Or can we otherwise build UserAuthenticator without app auth?
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
await flow.setup()
return flow
# ----- Commands -----
@cmds.hybrid_command(name='auth')
async def cmd_auth(self, ctx: LionContext):
if ctx.interaction:
await ctx.interaction.response.defer(ephemeral=True)
flow = await self.start_auth()
await ctx.reply(flow.auth.return_auth_url())
await flow.run()
await ctx.reply("Authentication Complete!")

79
src/twitch/data.py Normal file
View File

@@ -0,0 +1,79 @@
import datetime as dt
from data import Registry, RowModel, Table
from data.columns import Integer, String, Timestamp
class TwitchAuthData(Registry):
class UserAuthRow(RowModel):
"""
Schema
------
CREATE TABLE twitch_user_auth(
userid TEXT PRIMARY KEY,
access_token TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
refresh_token TEXT NOT NULL,
obtained_at TIMESTAMPTZ
);
"""
_tablename_ = 'twitch_user_auth'
_cache_ = {}
userid = Integer(primary=True)
access_token = String()
refresh_token = String()
expires_at = Timestamp()
obtained_at = Timestamp()
@classmethod
async def update_user_auth(
cls, userid: str, token: str, refresh: str,
expires_at: dt.datetime, obtained_at: dt.datetime,
scopes: list[str]
):
if cls._connector is None:
raise ValueError("Attempting to use uninitialised Registry.")
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
# Clear row for this userid
await cls.table.delete_where(userid=userid)
# Insert new user row
row = await cls.create(
userid=userid,
access_token=token,
refresh_token=refresh,
expires_at=expires_at,
obtained_at=obtained_at
)
# Insert new scope rows
if scopes:
await TwitchAuthData.user_scopes.insert_many(
('userid', 'scope'),
*((userid, scope) for scope in scopes)
)
return row
@classmethod
async def get_scopes_for(cls, userid: str) -> list[str]:
"""
Get a list of scopes stored for the given user.
Will return an empty list if the user is not authenticated.
"""
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
return [row.scope for row in rows] if rows else []
"""
Schema
------
CREATE TABLE twitch_user_scopes(
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
scope TEXT
);
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
"""
user_scopes = Table('twitch_user_scopes')

0
src/twitch/lib.py Normal file
View File

88
src/twitch/userflow.py Normal file
View File

@@ -0,0 +1,88 @@
from typing import Optional
import datetime as dt
from aiohttp import web
import aiohttp
from twitchAPI.twitch import Twitch
from twitchAPI.oauth import UserAuthenticator, validate_token
from twitchAPI.type import AuthType
from twitchio.client import asyncio
from meta.errors import SafeCancellation
from utils.lib import utc_now
from .data import TwitchAuthData
from . import logger
class UserAuthFlow:
auth: UserAuthenticator
data: TwitchAuthData
auth_ws: str
def __init__(self, data, auth, auth_ws):
self.auth = auth
self.data = data
self.auth_ws = auth_ws
self._setup_done = asyncio.Event()
self._comm_task: Optional[asyncio.Task] = None
async def setup(self):
"""
Establishes websocket connection to the AuthServer,
and requests listening for the given state.
Propagates any exceptions that occur during connection setup.
"""
if self._setup_done.is_set():
raise ValueError("UserAuthFlow is already set up.")
self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate')
await self._setup_done.wait()
if self._comm_task.done() and (exc := self._comm_task.exception()):
raise exc
async def _communicate(self):
async with aiohttp.ClientSession() as session:
async with session.ws_connect(self.auth_ws) as ws:
await ws.send_json({'state': self.auth.state})
self._setup_done.set()
return await ws.receive_json()
async def run(self) -> TwitchAuthData.UserAuthRow:
if not self._setup_done.is_set():
raise ValueError("Cannot run UserAuthFlow before setup.")
if self._comm_task is None:
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
result = await self._comm_task
if result.get('error', None):
# TODO Custom auth errors
# This is only documented to occur when the user denies the auth
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
if result.get('state', None) != self.auth.state:
# This should never happen unless the authserver has its wires crossed somehow,
# or the connection has been tampered with.
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
logger.critical(
f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG."
)
raise SafeCancellation(
"Could not complete authentication! Invalid server response."
)
# Now assume result has a valid code
# Exchange code for an auth token and a refresh token
# Ignore type here, authenticate returns None if a callback function has been given.
token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore
# Fetch the associated userid and basic info
v_result = await validate_token(token)
userid = v_result['user_id']
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
# Save auth data
return await self.data.UserAuthRow.update_user_auth(
userid=userid, token=token, refresh=refresh,
expires_at=expiry, obtained_at=utc_now(),
scopes=[scope.value for scope in self.auth.scopes]
)

7
tests/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
# !/bin/python3
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src"))

0
tests/gui/__init__.py Normal file
View File

View File

@@ -1,11 +1,11 @@
import asyncio
import datetime as dt
from src.cards import WeeklyGoalCard
from gui.cards import WeeklyGoalCard
async def get_card():
card = await WeeklyGoalCard.generate_sample()
with open('samples/weekly-sample.png', 'wb') as image_file:
with open('output/weekly-sample.png', 'wb') as image_file:
image_file.write(card.fp.read())
if __name__ == '__main__':

View File

@@ -0,0 +1,15 @@
import asyncio
import datetime as dt
from gui.cards import BreakTimerCard, FocusTimerCard
async def get_card():
card = await BreakTimerCard.generate_sample()
with open('output/break_timer_sample.png', 'wb') as image_file:
image_file.write(card.fp.read())
card = await FocusTimerCard.generate_sample()
with open('output/focus_timer_sample.png', 'wb') as image_file:
image_file.write(card.fp.read())
if __name__ == '__main__':
asyncio.run(get_card())