Compare commits
37 Commits
timerlayou
...
feat/taskl
| Author | SHA1 | Date | |
|---|---|---|---|
| 1e7a5c9b8a | |||
| 592017ba5e | |||
| 49a8cefeef | |||
| d4870740a2 | |||
| 8991b1a641 | |||
| 79645177bd | |||
| 9b3b7265d3 | |||
| 3c0d527501 | |||
| 997804c6bf | |||
| 2cdd084bbe | |||
| 72d52b6014 | |||
| 92fee23afa | |||
| 83a63e8a6e | |||
| 63152f3475 | |||
| 81e25e7efc | |||
| ce07f7ae73 | |||
| d158aed257 | |||
| 47a52d9600 | |||
| fc459ac0dd | |||
| 45b57b4eca | |||
| 22b99717db | |||
| 2810365588 | |||
| e9946a9814 | |||
| 8f6fdf3381 | |||
| 9d0d19d046 | |||
| a7eb8d0f09 | |||
| 9c738ecb91 | |||
| 9c9107bf9d | |||
| caa907b6d9 | |||
| 44d6d77494 | |||
| f2c449d2e0 | |||
| 53366c0333 | |||
| 66f7680482 | |||
| 37f25f10ef | |||
| 7d327a5e2f | |||
| 41f755795f | |||
| bc073363b9 |
2
.gitmodules
vendored
2
.gitmodules
vendored
@@ -6,7 +6,7 @@
|
|||||||
url = git@github.com:Intery/CafeHelper-Skins.git
|
url = git@github.com:Intery/CafeHelper-Skins.git
|
||||||
[submodule "src/modules/voicefix"]
|
[submodule "src/modules/voicefix"]
|
||||||
path = src/modules/voicefix
|
path = src/modules/voicefix
|
||||||
url = https://github.com/Intery/StudyLion-voicefix.git
|
url = git@github.com:Intery/StudyLion-voicefix.git
|
||||||
[submodule "src/modules/streamalerts"]
|
[submodule "src/modules/streamalerts"]
|
||||||
path = src/modules/streamalerts
|
path = src/modules/streamalerts
|
||||||
url = https://github.com/Intery/StudyLion-streamalerts.git
|
url = https://github.com/Intery/StudyLion-streamalerts.git
|
||||||
|
|||||||
113
data/schema.sql
113
data/schema.sql
@@ -287,13 +287,14 @@ CREATE TABLE tasklist(
|
|||||||
deleted_at TIMESTAMPTZ,
|
deleted_at TIMESTAMPTZ,
|
||||||
completed_at TIMESTAMPTZ,
|
completed_at TIMESTAMPTZ,
|
||||||
created_at TIMESTAMPTZ,
|
created_at TIMESTAMPTZ,
|
||||||
last_updated_at TIMESTAMPTZ
|
last_updated_at TIMESTAMPTZ,
|
||||||
|
duration INTEGER
|
||||||
);
|
);
|
||||||
CREATE INDEX tasklist_users ON tasklist (userid);
|
CREATE INDEX tasklist_users ON tasklist (userid);
|
||||||
ALTER TABLE tasklist
|
ALTER TABLE tasklist
|
||||||
ADD CONSTRAINT fk_tasklist_users
|
ADD CONSTRAINT fk_tasklist_users
|
||||||
FOREIGN KEY (userid)
|
FOREIGN KEY (userid)
|
||||||
REFERENCES user_config (userid)
|
REFERENCES user_profiles (profileid)
|
||||||
ON DELETE CASCADE
|
ON DELETE CASCADE
|
||||||
NOT VALID;
|
NOT VALID;
|
||||||
ALTER TABLE tasklist
|
ALTER TABLE tasklist
|
||||||
@@ -317,6 +318,20 @@ CREATE TABLE tasklist_reward_history(
|
|||||||
reward_count INTEGER
|
reward_count INTEGER
|
||||||
);
|
);
|
||||||
CREATE INDEX tasklist_reward_history_users ON tasklist_reward_history (userid, reward_time);
|
CREATE INDEX tasklist_reward_history_users ON tasklist_reward_history (userid, reward_time);
|
||||||
|
|
||||||
|
CREATE TABLE tasklist_current(
|
||||||
|
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
started_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE tasklist_planner(
|
||||||
|
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
sortkey INTEGER
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
-- }}}
|
-- }}}
|
||||||
|
|
||||||
-- Reminder data {{{
|
-- Reminder data {{{
|
||||||
@@ -1454,6 +1469,7 @@ CREATE TABLE shoutouts(
|
|||||||
CREATE TABLE counters(
|
CREATE TABLE counters(
|
||||||
counterid SERIAL PRIMARY KEY,
|
counterid SERIAL PRIMARY KEY,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
|
category TEXT,
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
);
|
);
|
||||||
CREATE UNIQUE INDEX counters_name ON counters (name);
|
CREATE UNIQUE INDEX counters_name ON counters (name);
|
||||||
@@ -1464,6 +1480,7 @@ CREATE TABLE counter_log(
|
|||||||
userid INTEGER NOT NULL,
|
userid INTEGER NOT NULL,
|
||||||
value INTEGER NOT NULL,
|
value INTEGER NOT NULL,
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
details TEXT,
|
||||||
context_str TEXT
|
context_str TEXT
|
||||||
);
|
);
|
||||||
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
||||||
@@ -1484,6 +1501,98 @@ CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name
|
|||||||
|
|
||||||
-- }}}
|
-- }}}
|
||||||
|
|
||||||
|
-- Voice Roles {{{
|
||||||
|
CREATE TABLE voice_roles(
|
||||||
|
voice_role_id SERIAL PRIMARY KEY,
|
||||||
|
channelid BIGINT NOT NULL,
|
||||||
|
roleid BIGINT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX voice_role_channels on voice_roles (channelid);
|
||||||
|
|
||||||
|
-- }}}
|
||||||
|
|
||||||
|
-- User and Community Profiles {{{
|
||||||
|
DROP TABLE IF EXISTS community_members;
|
||||||
|
DROP TABLE IF EXISTS communities_twitch;
|
||||||
|
DROP TABLE IF EXISTS communities_discord;
|
||||||
|
DROP TABLE IF EXISTS communities;
|
||||||
|
DROP TABLE IF EXISTS profiles_twitch;
|
||||||
|
DROP TABLE IF EXISTS profiles_discord;
|
||||||
|
DROP TABLE IF EXISTS user_profiles;
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE user_profiles(
|
||||||
|
profileid SERIAL PRIMARY KEY,
|
||||||
|
nickname TEXT,
|
||||||
|
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE profiles_discord(
|
||||||
|
linkid SERIAL PRIMARY KEY,
|
||||||
|
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
userid BIGINT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid);
|
||||||
|
CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid);
|
||||||
|
|
||||||
|
CREATE TABLE profiles_twitch(
|
||||||
|
linkid SERIAL PRIMARY KEY,
|
||||||
|
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
userid TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid);
|
||||||
|
CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE communities(
|
||||||
|
communityid SERIAL PRIMARY KEY,
|
||||||
|
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE communities_discord(
|
||||||
|
guildid BIGINT PRIMARY KEY,
|
||||||
|
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX communities_discord_communityid ON communities_discord (communityid);
|
||||||
|
|
||||||
|
CREATE TABLE communities_twitch(
|
||||||
|
channelid TEXT PRIMARY KEY,
|
||||||
|
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX communities_twitch_communityid ON communities_twitch (communityid);
|
||||||
|
|
||||||
|
CREATE TABLE community_members(
|
||||||
|
memberid SERIAL PRIMARY KEY,
|
||||||
|
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
|
||||||
|
-- }}}
|
||||||
|
|
||||||
|
-- Twitch User Auth {{{
|
||||||
|
CREATE TABLE twitch_user_auth(
|
||||||
|
userid TEXT PRIMARY KEY,
|
||||||
|
access_token TEXT NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
refresh_token TEXT NOT NULL,
|
||||||
|
obtained_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE twitch_user_scopes(
|
||||||
|
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
scope TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||||
|
|
||||||
|
-- }}}
|
||||||
|
|
||||||
|
|
||||||
-- Analytics Data {{{
|
-- Analytics Data {{{
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ async def main():
|
|||||||
config=conf,
|
config=conf,
|
||||||
initial_extensions=[
|
initial_extensions=[
|
||||||
'utils', 'core', 'analytics',
|
'utils', 'core', 'analytics',
|
||||||
|
'twitch',
|
||||||
'modules',
|
'modules',
|
||||||
'babel',
|
'babel',
|
||||||
'tracking.voice', 'tracking.text',
|
'tracking.voice', 'tracking.text',
|
||||||
|
|||||||
2
src/gui
2
src/gui
Submodule src/gui updated: 40bc140355...62d2484914
@@ -1,9 +1,12 @@
|
|||||||
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import twitchio
|
||||||
from twitchio.ext import commands
|
from twitchio.ext import commands
|
||||||
from twitchio.ext import pubsub
|
from twitchio.ext import pubsub
|
||||||
|
from twitchio.ext.commands.core import itertools
|
||||||
|
|
||||||
from data import Database
|
from data import Database
|
||||||
|
|
||||||
@@ -23,5 +26,51 @@ class CrocBot(commands.Bot):
|
|||||||
self.data = data
|
self.data = data
|
||||||
self.pubsub = pubsub.PubSubPool(self)
|
self.pubsub = pubsub.PubSubPool(self)
|
||||||
|
|
||||||
|
self._member_cache = defaultdict(dict)
|
||||||
|
|
||||||
async def event_ready(self):
|
async def event_ready(self):
|
||||||
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
||||||
|
|
||||||
|
async def event_join(self, channel: twitchio.Channel, user: twitchio.User):
|
||||||
|
self._member_cache[channel.name][user.name] = user
|
||||||
|
|
||||||
|
async def event_message(self, message: twitchio.Message):
|
||||||
|
if message.channel and message.author:
|
||||||
|
self._member_cache[message.channel.name][message.author.name] = message.author
|
||||||
|
await self.handle_commands(message)
|
||||||
|
|
||||||
|
async def seek_user(self, userstr: str, matching=True, fuzzy=True):
|
||||||
|
if userstr.startswith('@'):
|
||||||
|
matching = False
|
||||||
|
userstr = userstr.strip('@ ')
|
||||||
|
|
||||||
|
result = None
|
||||||
|
if matching and len(userstr) >= 3:
|
||||||
|
lowered = userstr.lower()
|
||||||
|
full_matches = []
|
||||||
|
for user in itertools.chain(*(cmems.values() for cmems in self._member_cache.values())):
|
||||||
|
matchstr = user.name.lower()
|
||||||
|
print(matchstr)
|
||||||
|
if matchstr.startswith(lowered):
|
||||||
|
result = user
|
||||||
|
break
|
||||||
|
if lowered in matchstr:
|
||||||
|
full_matches.append(user)
|
||||||
|
if result is None and full_matches:
|
||||||
|
result = full_matches[0]
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
lookup = userstr
|
||||||
|
elif result.id is None:
|
||||||
|
lookup = result.name
|
||||||
|
else:
|
||||||
|
lookup = None
|
||||||
|
|
||||||
|
if lookup:
|
||||||
|
found = await self.fetch_users(names=[lookup])
|
||||||
|
if found:
|
||||||
|
result = found[0]
|
||||||
|
|
||||||
|
# No matches found
|
||||||
|
return result
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ if TYPE_CHECKING:
|
|||||||
from meta.CrocBot import CrocBot
|
from meta.CrocBot import CrocBot
|
||||||
from core.cog import CoreCog
|
from core.cog import CoreCog
|
||||||
from core.config import ConfigCog
|
from core.config import ConfigCog
|
||||||
|
from twitch.cog import TwitchAuthCog
|
||||||
from tracking.voice.cog import VoiceTrackerCog
|
from tracking.voice.cog import VoiceTrackerCog
|
||||||
from tracking.text.cog import TextTrackerCog
|
from tracking.text.cog import TextTrackerCog
|
||||||
from modules.config.cog import GuildConfigCog
|
from modules.config.cog import GuildConfigCog
|
||||||
@@ -49,6 +50,7 @@ if TYPE_CHECKING:
|
|||||||
from modules.topgg.cog import TopggCog
|
from modules.topgg.cog import TopggCog
|
||||||
from modules.user_config.cog import UserConfigCog
|
from modules.user_config.cog import UserConfigCog
|
||||||
from modules.video_channels.cog import VideoCog
|
from modules.video_channels.cog import VideoCog
|
||||||
|
from modules.profiles.cog import ProfileCog
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -90,6 +92,10 @@ class LionBot(Bot):
|
|||||||
def core(self):
|
def core(self):
|
||||||
return self.get_cog('CoreCog')
|
return self.get_cog('CoreCog')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def profiles(self):
|
||||||
|
return self.get_cog('ProfileCog')
|
||||||
|
|
||||||
async def _handle_global_dispatch(self, event_name: str, *args, **kwargs):
|
async def _handle_global_dispatch(self, event_name: str, *args, **kwargs):
|
||||||
self.dispatch(event_name, *args, **kwargs)
|
self.dispatch(event_name, *args, **kwargs)
|
||||||
|
|
||||||
@@ -142,6 +148,10 @@ class LionBot(Bot):
|
|||||||
# To make the type checker happy about fetching cogs by name
|
# To make the type checker happy about fetching cogs by name
|
||||||
# TODO: Move this to stubs at some point
|
# TODO: Move this to stubs at some point
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_cog(self, name: Literal['ProfileCog']) -> 'ProfileCog':
|
||||||
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
|
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
|
||||||
...
|
...
|
||||||
@@ -154,6 +164,10 @@ class LionBot(Bot):
|
|||||||
def get_cog(self, name: Literal['VoiceTrackerCog']) -> 'VoiceTrackerCog':
|
def get_cog(self, name: Literal['VoiceTrackerCog']) -> 'VoiceTrackerCog':
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_cog(self, name: Literal['TwitchAuthCog']) -> 'TwitchAuthCog':
|
||||||
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_cog(self, name: Literal['TextTrackerCog']) -> 'TextTrackerCog':
|
def get_cog(self, name: Literal['TextTrackerCog']) -> 'TextTrackerCog':
|
||||||
...
|
...
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ class LionCog(Cog):
|
|||||||
cls._placeholder_groups_ = set()
|
cls._placeholder_groups_ = set()
|
||||||
cls._twitch_cmds_ = {}
|
cls._twitch_cmds_ = {}
|
||||||
cls._twitch_events_ = {}
|
cls._twitch_events_ = {}
|
||||||
|
cls._twitch_events_loaded_ = set()
|
||||||
|
|
||||||
for base in reversed(cls.__mro__):
|
for base in reversed(cls.__mro__):
|
||||||
for elem, value in base.__dict__.items():
|
for elem, value in base.__dict__.items():
|
||||||
@@ -47,6 +48,27 @@ class LionCog(Cog):
|
|||||||
|
|
||||||
return await super()._inject(bot, *args, *kwargs)
|
return await super()._inject(bot, *args, *kwargs)
|
||||||
|
|
||||||
|
def add_twitch_command(self, bot: Bot, command: Command):
|
||||||
|
"""
|
||||||
|
Dynamically register a command with the given bot.
|
||||||
|
|
||||||
|
The command will be deregistered on cog unload.
|
||||||
|
"""
|
||||||
|
# Remove any conflicting commands
|
||||||
|
if cmd := bot.get_command(command.name):
|
||||||
|
bot.remove_command(cmd.name)
|
||||||
|
self._twitch_cmds_.pop(command.name, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._twitch_cmds_[command.name] = command
|
||||||
|
command._instance = self
|
||||||
|
command.cog = self
|
||||||
|
bot.add_command(command)
|
||||||
|
except Exception:
|
||||||
|
# Ensure the command doesn't die in the internal command cache
|
||||||
|
self._twitch_cmds_.pop(command.name, None)
|
||||||
|
raise
|
||||||
|
|
||||||
def _load_twitch_methods(self, bot: Bot):
|
def _load_twitch_methods(self, bot: Bot):
|
||||||
for name, command in self._twitch_cmds_.items():
|
for name, command in self._twitch_cmds_.items():
|
||||||
command._instance = self
|
command._instance = self
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ if TYPE_CHECKING:
|
|||||||
from core.lion_member import LionMember
|
from core.lion_member import LionMember
|
||||||
from core.lion_user import LionUser
|
from core.lion_user import LionUser
|
||||||
from core.lion_guild import LionGuild
|
from core.lion_guild import LionGuild
|
||||||
|
from modules.profiles.profile import UserProfile
|
||||||
|
from modules.profiles.community import Community
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -54,6 +56,8 @@ class LionContext(Context['LionBot']):
|
|||||||
lguild: 'LionGuild'
|
lguild: 'LionGuild'
|
||||||
lmember: 'LionMember'
|
lmember: 'LionMember'
|
||||||
alion: 'LionUser | LionMember'
|
alion: 'LionUser | LionMember'
|
||||||
|
profile: 'UserProfile'
|
||||||
|
community: 'Community'
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
parts = {}
|
parts = {}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ this_package = 'modules'
|
|||||||
|
|
||||||
active_discord = [
|
active_discord = [
|
||||||
'.sysadmin',
|
'.sysadmin',
|
||||||
|
'.profiles',
|
||||||
'.config',
|
'.config',
|
||||||
'.user_config',
|
'.user_config',
|
||||||
'.skins',
|
'.skins',
|
||||||
@@ -30,6 +31,7 @@ active_discord = [
|
|||||||
'.nowdoing',
|
'.nowdoing',
|
||||||
'.shoutouts',
|
'.shoutouts',
|
||||||
'.tagstrings',
|
'.tagstrings',
|
||||||
|
'.voiceroles',
|
||||||
]
|
]
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot):
|
||||||
|
|||||||
@@ -25,6 +25,75 @@ class PERIOD(Enum):
|
|||||||
YEAR = ('this year', 'y', 'year', 'yearly')
|
YEAR = ('this year', 'y', 'year', 'yearly')
|
||||||
|
|
||||||
|
|
||||||
|
def counter_cmd_factory(
|
||||||
|
counter: str,
|
||||||
|
response: str,
|
||||||
|
default_period: Optional[PERIOD] = PERIOD.STREAM,
|
||||||
|
context: Optional[str] = None
|
||||||
|
):
|
||||||
|
context = context or f"cmd: {counter}"
|
||||||
|
async def counter_cmd(cog, ctx: commands.Context, *, args: Optional[str] = None):
|
||||||
|
userid = int(ctx.author.id)
|
||||||
|
channelid = int((await ctx.channel.user()).id)
|
||||||
|
period, start_time = await cog.parse_period(channelid, '', default=default_period)
|
||||||
|
|
||||||
|
args = (args or '').strip(" ")
|
||||||
|
splits = args.split(maxsplit=1)
|
||||||
|
splits = [split.strip() for split in splits if split]
|
||||||
|
|
||||||
|
details = None
|
||||||
|
amount = 1
|
||||||
|
|
||||||
|
if splits:
|
||||||
|
if splits[0].isdigit() or (splits[0].startswith('-') and splits[0][1:].isdigit()):
|
||||||
|
amount = int(splits[0])
|
||||||
|
splits = splits[1:]
|
||||||
|
if splits:
|
||||||
|
details = ' '.join(splits)
|
||||||
|
|
||||||
|
await cog.add_to_counter(
|
||||||
|
counter, userid, amount,
|
||||||
|
context=context,
|
||||||
|
details=details
|
||||||
|
)
|
||||||
|
lb = await cog.leaderboard(counter, start_time=start_time)
|
||||||
|
user_total = lb.get(userid, 0)
|
||||||
|
total = sum(lb.values())
|
||||||
|
await ctx.reply(
|
||||||
|
response.format(
|
||||||
|
total=total,
|
||||||
|
period=period,
|
||||||
|
period_name=period.value[0],
|
||||||
|
detailsorname=details or counter,
|
||||||
|
user_total=user_total,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def lb_cmd(cog, ctx: commands.Context, *, args: str = ''):
|
||||||
|
user = await ctx.channel.user()
|
||||||
|
await ctx.reply(await cog.formatted_lb(counter, args, int(user.id)))
|
||||||
|
|
||||||
|
async def undo_cmd(cog, ctx: commands.Context):
|
||||||
|
userid = int(ctx.author.id)
|
||||||
|
channelid = int((await ctx.channel.user()).id)
|
||||||
|
_counter = await cog.fetch_counter(counter)
|
||||||
|
query = cog.data.CounterEntry.fetch_where(
|
||||||
|
counterid=_counter.counterid,
|
||||||
|
userid=userid,
|
||||||
|
)
|
||||||
|
query.order_by('created_at', direction=ORDER.DESC)
|
||||||
|
query.limit(1)
|
||||||
|
results = await query
|
||||||
|
if not results:
|
||||||
|
await ctx.reply("Nothing to delete!")
|
||||||
|
else:
|
||||||
|
row = results[0]
|
||||||
|
await row.delete()
|
||||||
|
await ctx.reply("Undo successful!")
|
||||||
|
|
||||||
|
return (counter_cmd, lb_cmd, undo_cmd)
|
||||||
|
|
||||||
|
|
||||||
class CounterCog(LionCog):
|
class CounterCog(LionCog):
|
||||||
def __init__(self, bot: LionBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
@@ -38,6 +107,7 @@ class CounterCog(LionCog):
|
|||||||
|
|
||||||
async def cog_load(self):
|
async def cog_load(self):
|
||||||
self._load_twitch_methods(self.crocbot)
|
self._load_twitch_methods(self.crocbot)
|
||||||
|
await self.load_counter_commands()
|
||||||
|
|
||||||
await self.data.init()
|
await self.data.init()
|
||||||
await self.load_counters()
|
await self.load_counters()
|
||||||
@@ -46,6 +116,29 @@ class CounterCog(LionCog):
|
|||||||
async def cog_unload(self):
|
async def cog_unload(self):
|
||||||
self._unload_twitch_methods(self.crocbot)
|
self._unload_twitch_methods(self.crocbot)
|
||||||
|
|
||||||
|
async def load_counter_commands(self):
|
||||||
|
rows = await self.data.CounterCommand.fetch_where()
|
||||||
|
for row in rows:
|
||||||
|
counter = await self.data.Counter.fetch(row.counterid)
|
||||||
|
counter_cb, lb_cb, undo_cb = counter_cmd_factory(
|
||||||
|
counter.name,
|
||||||
|
row.response
|
||||||
|
)
|
||||||
|
cmds = []
|
||||||
|
main_cmd = commands.command(name=row.name)(counter_cb)
|
||||||
|
cmds.append(main_cmd)
|
||||||
|
if row.lbname:
|
||||||
|
lb_cmd = commands.command(name=row.lbname)(lb_cb)
|
||||||
|
cmds.append(lb_cmd)
|
||||||
|
if row.undoname:
|
||||||
|
undo_cmd = commands.command(name=row.undoname)(undo_cb)
|
||||||
|
cmds.append(undo_cmd)
|
||||||
|
|
||||||
|
for cmd in cmds:
|
||||||
|
self.add_twitch_command(self.crocbot, cmd)
|
||||||
|
|
||||||
|
logger.info(f"(Re)Loaded {len(rows)} counter commands!")
|
||||||
|
|
||||||
async def cog_check(self, ctx):
|
async def cog_check(self, ctx):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -80,13 +173,19 @@ class CounterCog(LionCog):
|
|||||||
if row:
|
if row:
|
||||||
await self.data.CounterEntry.table.delete_where(counterid=row.counterid)
|
await self.data.CounterEntry.table.delete_where(counterid=row.counterid)
|
||||||
|
|
||||||
async def add_to_counter(self, counter: str, userid: int, value: int, context: Optional[str]=None):
|
async def add_to_counter(
|
||||||
|
self,
|
||||||
|
counter: str, userid: int, value: int,
|
||||||
|
context: Optional[str]=None,
|
||||||
|
details: Optional[str]=None,
|
||||||
|
):
|
||||||
row = await self.fetch_counter(counter)
|
row = await self.fetch_counter(counter)
|
||||||
return await self.data.CounterEntry.create(
|
return await self.data.CounterEntry.create(
|
||||||
counterid=row.counterid,
|
counterid=row.counterid,
|
||||||
userid=userid,
|
userid=userid,
|
||||||
value=value,
|
value=value,
|
||||||
context_str=context
|
context_str=context,
|
||||||
|
details=details
|
||||||
)
|
)
|
||||||
|
|
||||||
async def leaderboard(self, counter: str, start_time=None):
|
async def leaderboard(self, counter: str, start_time=None):
|
||||||
@@ -155,8 +254,43 @@ class CounterCog(LionCog):
|
|||||||
elif subcmd == 'clear':
|
elif subcmd == 'clear':
|
||||||
await self.reset_counter(name)
|
await self.reset_counter(name)
|
||||||
await ctx.reply(f"'{name}' counter reset.")
|
await ctx.reply(f"'{name}' counter reset.")
|
||||||
|
elif subcmd == 'alias':
|
||||||
|
splits = args.split(maxsplit=3) if args else []
|
||||||
|
counter = await self.fetch_counter(name)
|
||||||
|
rows = await self.data.CounterCommand.fetch_where(counterid=counter.counterid)
|
||||||
|
existing = rows[0] if rows else None
|
||||||
|
if existing and not args:
|
||||||
|
# Show current alias
|
||||||
|
await ctx.reply(
|
||||||
|
f"Counter '{name}' aliases: '!{existing.name}' to add to counter; "
|
||||||
|
f"'!{existing.lbname}' to view counter leaderboard; "
|
||||||
|
f"'!{existing.undoname}' to undo (your) last addition."
|
||||||
|
)
|
||||||
|
elif len(splits) < 4:
|
||||||
|
# Show usage
|
||||||
|
await ctx.reply(
|
||||||
|
"USAGE: !counter <name> alias <cmdname> <lbname> <undoname> <response> -- "
|
||||||
|
"Response accepts keywords {total}, {period}, {period_name}, {detailsorname}, {user_total}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create new alias
|
||||||
|
cmdname, lbname, undoname, response = splits
|
||||||
|
# Remove any existing alias
|
||||||
|
await self.data.CounterCommand.table.delete_where(name=cmdname)
|
||||||
|
|
||||||
|
alias = await self.data.CounterCommand.create(
|
||||||
|
name=cmdname,
|
||||||
|
counterid=counter.counterid,
|
||||||
|
lbname=lbname, undoname=undoname, response=response
|
||||||
|
)
|
||||||
|
await self.load_counter_commands()
|
||||||
|
await ctx.reply(
|
||||||
|
f"Alias created for counter '{name}': '!{alias.name}' to add to counter; "
|
||||||
|
f"'!{alias.lbname}' to view counter leaderboard; "
|
||||||
|
f"'!{alias.undoname}' to undo (your) last addition."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear'.")
|
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear', 'alias'.")
|
||||||
|
|
||||||
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
|
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
|
||||||
if periodstr:
|
if periodstr:
|
||||||
@@ -211,82 +345,3 @@ class CounterCog(LionCog):
|
|||||||
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}"
|
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}"
|
||||||
else:
|
else:
|
||||||
return f"{counter} {period.value[-1]} leaderboard is empty!"
|
return f"{counter} {period.value[-1]} leaderboard is empty!"
|
||||||
|
|
||||||
# Misc actual counter commands
|
|
||||||
# TODO: Factor this out to a different module...
|
|
||||||
@commands.command()
|
|
||||||
async def tea(self, ctx: commands.Context, *, args: Optional[str]=None):
|
|
||||||
userid = int(ctx.author.id)
|
|
||||||
channelid = int((await ctx.channel.user()).id)
|
|
||||||
period, start_time = await self.parse_period(channelid, '')
|
|
||||||
counter = 'tea'
|
|
||||||
|
|
||||||
await self.add_to_counter(
|
|
||||||
counter,
|
|
||||||
userid,
|
|
||||||
1,
|
|
||||||
context='cmd: tea'
|
|
||||||
)
|
|
||||||
lb = await self.leaderboard(counter, start_time=start_time)
|
|
||||||
user_total = lb.get(userid, 0)
|
|
||||||
total = sum(lb.values())
|
|
||||||
await ctx.reply(f"Enjoy your tea! We have had {total} cups of tea {period.value[0]}.")
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def tealb(self, ctx: commands.Context, *, args: str = ''):
|
|
||||||
user = await ctx.channel.user()
|
|
||||||
await ctx.reply(await self.formatted_lb('tea', args, int(user.id)))
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def coffee(self, ctx: commands.Context, *, args: Optional[str]=None):
|
|
||||||
userid = int(ctx.author.id)
|
|
||||||
channelid = int((await ctx.channel.user()).id)
|
|
||||||
period, start_time = await self.parse_period(channelid, '')
|
|
||||||
counter = 'coffee'
|
|
||||||
|
|
||||||
await self.add_to_counter(
|
|
||||||
counter,
|
|
||||||
userid,
|
|
||||||
1,
|
|
||||||
context='cmd: coffee'
|
|
||||||
)
|
|
||||||
lb = await self.leaderboard(counter, start_time=start_time)
|
|
||||||
user_total = lb.get(userid, 0)
|
|
||||||
total = sum(lb.values())
|
|
||||||
await ctx.reply(f"Enjoy your coffee! We have had {total} cups of coffee {period.value[0]}.")
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def coffeelb(self, ctx: commands.Context, *, args: str = ''):
|
|
||||||
user = await ctx.channel.user()
|
|
||||||
await ctx.reply(await self.formatted_lb('coffee', args, int(user.id)))
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def water(self, ctx: commands.Context, *, args: Optional[str]=None):
|
|
||||||
userid = int(ctx.author.id)
|
|
||||||
channelid = int((await ctx.channel.user()).id)
|
|
||||||
period, start_time = await self.parse_period(channelid, '')
|
|
||||||
counter = 'water'
|
|
||||||
|
|
||||||
await self.add_to_counter(
|
|
||||||
counter,
|
|
||||||
userid,
|
|
||||||
1,
|
|
||||||
context='cmd: water'
|
|
||||||
)
|
|
||||||
lb = await self.leaderboard(counter, start_time=start_time)
|
|
||||||
user_total = lb.get(userid, 0)
|
|
||||||
total = sum(lb.values())
|
|
||||||
await ctx.reply(f"Good job hydrating! We have had {total} cups of water {period.value[0]}.")
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def waterlb(self, ctx: commands.Context, *, args: str = ''):
|
|
||||||
user = await ctx.channel.user()
|
|
||||||
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def stuff(self, ctx: commands.Context, *, args: str = ''):
|
|
||||||
await ctx.reply(f"Stuff {args}")
|
|
||||||
|
|
||||||
@cmds.hybrid_command('water')
|
|
||||||
async def d_water_cmd(self, ctx):
|
|
||||||
await ctx.reply(repr(ctx))
|
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ class CounterData(Registry):
|
|||||||
CREATE TABLE counters(
|
CREATE TABLE counters(
|
||||||
counterid SERIAL PRIMARY KEY,
|
counterid SERIAL PRIMARY KEY,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
category TEXT
|
||||||
);
|
);
|
||||||
CREATE UNIQUE INDEX counters_name ON counters (name);
|
CREATE UNIQUE INDEX counters_name ON counters (name);
|
||||||
"""
|
"""
|
||||||
@@ -19,6 +20,7 @@ class CounterData(Registry):
|
|||||||
|
|
||||||
counterid = Integer(primary=True)
|
counterid = Integer(primary=True)
|
||||||
name = String()
|
name = String()
|
||||||
|
category = String()
|
||||||
created_at = Timestamp()
|
created_at = Timestamp()
|
||||||
|
|
||||||
class CounterEntry(RowModel):
|
class CounterEntry(RowModel):
|
||||||
@@ -31,7 +33,8 @@ class CounterData(Registry):
|
|||||||
userid INTEGER NOT NULL,
|
userid INTEGER NOT NULL,
|
||||||
value INTEGER NOT NULL,
|
value INTEGER NOT NULL,
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
context_str TEXT
|
context_str TEXT,
|
||||||
|
details TEXT
|
||||||
);
|
);
|
||||||
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
||||||
"""
|
"""
|
||||||
@@ -44,5 +47,28 @@ class CounterData(Registry):
|
|||||||
value = Integer()
|
value = Integer()
|
||||||
created_at = Timestamp()
|
created_at = Timestamp()
|
||||||
context_str = String()
|
context_str = String()
|
||||||
|
details = String()
|
||||||
|
|
||||||
|
class CounterCommand(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE counter_commands(
|
||||||
|
name TEXT PRIMARY KEY,
|
||||||
|
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||||
|
lbname TEXT,
|
||||||
|
undoname TEXT,
|
||||||
|
response TEXT NOT NULL
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
# NOTE: This table will be replaced by aliases soon anyway
|
||||||
|
# So no need to worry about integrity or future-proofing
|
||||||
|
_tablename_ = 'counter_commands'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
name = String(primary=True)
|
||||||
|
counterid = Integer()
|
||||||
|
lbname = String()
|
||||||
|
undoname = String()
|
||||||
|
response = String()
|
||||||
|
|
||||||
|
|||||||
9
src/modules/counters/migration.sql
Normal file
9
src/modules/counters/migration.sql
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
ALTER TABLE counters ADD COLUMN category TEXT;
|
||||||
|
ALTER TABLE counter_log ADD COLUMN details TEXT;
|
||||||
|
CREATE TABLE counter_commands(
|
||||||
|
name TEXT PRIMARY KEY,
|
||||||
|
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||||
|
lbname TEXT,
|
||||||
|
undoname TEXT,
|
||||||
|
response TEXT NOT NULL
|
||||||
|
);
|
||||||
@@ -4,17 +4,21 @@ import json
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from attr import dataclass
|
import discord
|
||||||
|
from discord.ext import commands as cmds
|
||||||
|
from discord import app_commands as appcmds
|
||||||
|
|
||||||
import twitchio
|
import twitchio
|
||||||
from twitchio.ext import commands
|
from twitchio.ext import commands
|
||||||
|
|
||||||
from meta import CrocBot, LionCog
|
from meta import CrocBot, LionCog, LionContext, LionBot
|
||||||
from meta.LionBot import LionBot
|
|
||||||
from meta.sockets import Channel, register_channel
|
from meta.sockets import Channel, register_channel
|
||||||
from utils.lib import strfdelta, utc_now
|
from utils.lib import strfdelta, utc_now
|
||||||
from . import logger
|
from . import logger
|
||||||
from .data import NowListData
|
from .data import NowListData
|
||||||
|
|
||||||
|
from modules.profiles.profile import UserProfile
|
||||||
|
|
||||||
|
|
||||||
class NowDoingChannel(Channel):
|
class NowDoingChannel(Channel):
|
||||||
name = 'NowList'
|
name = 'NowList'
|
||||||
@@ -25,19 +29,7 @@ class NowDoingChannel(Channel):
|
|||||||
|
|
||||||
async def on_connection(self, websocket, event):
|
async def on_connection(self, websocket, event):
|
||||||
await super().on_connection(websocket, event)
|
await super().on_connection(websocket, event)
|
||||||
for task in self.cog.tasks.values():
|
await self.reload_tasklist(websocket=websocket)
|
||||||
await self.send_set(*self.task_args(task), websocket=websocket)
|
|
||||||
|
|
||||||
async def send_test_set(self):
|
|
||||||
tasks = [
|
|
||||||
(0, 'Tester0', "Testing Tasklist", True),
|
|
||||||
(1, 'Tester1', "Getting Confused", False),
|
|
||||||
(2, "Tester2", "Generating Bugs", True),
|
|
||||||
(3, "Tester3", "Fixing Bugs", False),
|
|
||||||
(4, "Tester4", "Pushing the red button", False),
|
|
||||||
]
|
|
||||||
for task in tasks:
|
|
||||||
await self.send_set(*task)
|
|
||||||
|
|
||||||
def task_args(self, task: NowListData.Task):
|
def task_args(self, task: NowListData.Task):
|
||||||
return (
|
return (
|
||||||
@@ -48,6 +40,14 @@ class NowDoingChannel(Channel):
|
|||||||
task.done_at.isoformat() if task.done_at else None,
|
task.done_at.isoformat() if task.done_at else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def reload_tasklist(self, websocket=None):
|
||||||
|
"""
|
||||||
|
Clear tasklist and re-send current tasks.
|
||||||
|
"""
|
||||||
|
await self.send_clear(websocket=websocket)
|
||||||
|
for task in self.cog.tasks.values():
|
||||||
|
await self.send_set(*self.task_args(task), websocket=websocket)
|
||||||
|
|
||||||
async def send_set(self, userid, name, task, start_at, end_at, websocket=None):
|
async def send_set(self, userid, name, task, start_at, end_at, websocket=None):
|
||||||
await self.send_event({
|
await self.send_event({
|
||||||
'type': "DO",
|
'type': "DO",
|
||||||
@@ -61,28 +61,28 @@ class NowDoingChannel(Channel):
|
|||||||
}
|
}
|
||||||
}, websocket=websocket)
|
}, websocket=websocket)
|
||||||
|
|
||||||
async def send_del(self, userid):
|
async def send_del(self, userid, websocket=None):
|
||||||
await self.send_event({
|
await self.send_event({
|
||||||
'type': "DO",
|
'type': "DO",
|
||||||
'method': "delTask",
|
'method': "delTask",
|
||||||
'args': {
|
'args': {
|
||||||
'userid': userid,
|
'userid': userid,
|
||||||
}
|
}
|
||||||
})
|
}, websocket=websocket)
|
||||||
|
|
||||||
async def send_clear(self):
|
async def send_clear(self, websocket=None):
|
||||||
await self.send_event({
|
await self.send_event({
|
||||||
'type': "DO",
|
'type': "DO",
|
||||||
'method': "clearTasks",
|
'method': "clearTasks",
|
||||||
'args': {
|
'args': {
|
||||||
}
|
}
|
||||||
})
|
}, websocket=websocket)
|
||||||
|
|
||||||
|
|
||||||
class NowDoingCog(LionCog):
|
class NowDoingCog(LionCog):
|
||||||
def __init__(self, bot: LionBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.crocbot = bot.crocbot
|
self.crocbot: CrocBot = bot.crocbot
|
||||||
self.data = bot.db.load_registry(NowListData())
|
self.data = bot.db.load_registry(NowListData())
|
||||||
self.channel = NowDoingChannel(self)
|
self.channel = NowDoingChannel(self)
|
||||||
register_channel(self.channel.name, self.channel)
|
register_channel(self.channel.name, self.channel)
|
||||||
@@ -94,17 +94,82 @@ class NowDoingCog(LionCog):
|
|||||||
|
|
||||||
async def cog_load(self):
|
async def cog_load(self):
|
||||||
await self.data.init()
|
await self.data.init()
|
||||||
|
|
||||||
await self.load_tasks()
|
await self.load_tasks()
|
||||||
|
|
||||||
|
self.bot.get_cog('ProfileCog').add_profile_migrator(self.migrate_profiles, name='task-migrator')
|
||||||
|
|
||||||
self._load_twitch_methods(self.crocbot)
|
self._load_twitch_methods(self.crocbot)
|
||||||
self.loaded.set()
|
self.loaded.set()
|
||||||
|
|
||||||
async def cog_unload(self):
|
async def cog_unload(self):
|
||||||
self.loaded.clear()
|
self.loaded.clear()
|
||||||
self.tasks.clear()
|
self.tasks.clear()
|
||||||
|
if profiles := self.bot.get_cog('ProfileCog'):
|
||||||
|
profiles.del_profile_migrator('task-migrator')
|
||||||
self._unload_twitch_methods(self.crocbot)
|
self._unload_twitch_methods(self.crocbot)
|
||||||
|
|
||||||
|
async def migrate_profiles(self, source_profile: UserProfile, target_profile: UserProfile):
|
||||||
|
"""
|
||||||
|
Move current source task to target profile if there's room for it, otherwise annihilate
|
||||||
|
"""
|
||||||
|
await self.load_tasks()
|
||||||
|
source_task = self.tasks.pop(source_profile.profileid, None)
|
||||||
|
|
||||||
|
results = ["(Tasklist)"]
|
||||||
|
|
||||||
|
if source_task:
|
||||||
|
target_task = self.tasks.get(target_profile.profileid, None)
|
||||||
|
if target_task and (target_task.is_done or target_task.started_at < source_task.started_at):
|
||||||
|
# If target is done, remove it so we can overwrite
|
||||||
|
results.append("Removed older task from target profile.")
|
||||||
|
await target_task.delete()
|
||||||
|
target_task = None
|
||||||
|
|
||||||
|
if not target_task:
|
||||||
|
# Update source task with new profile id
|
||||||
|
await source_task.update(userid=target_profile.profileid)
|
||||||
|
target_task = source_task
|
||||||
|
await self.channel.send_set(*self.channel.task_args(target_task))
|
||||||
|
results.append("Migrated 1 currently running task from source profile.")
|
||||||
|
else:
|
||||||
|
# If there is a target task we can't overwrite, just delete the source task
|
||||||
|
await source_task.delete()
|
||||||
|
results.append("Ignoring and removing older task from source profile.")
|
||||||
|
|
||||||
|
self.tasks.pop(source_profile.profileid, None)
|
||||||
|
await self.channel.send_del(source_profile.profileid)
|
||||||
|
else:
|
||||||
|
results.append("No running task in source profile, nothing to migrate!")
|
||||||
|
await self.load_tasks()
|
||||||
|
|
||||||
|
return ' '.join(results)
|
||||||
|
|
||||||
|
async def user_profile_migration(self):
|
||||||
|
"""
|
||||||
|
Manual single-use migration method from the old userid format to the new profileid format.
|
||||||
|
"""
|
||||||
|
await self.load_tasks()
|
||||||
|
for userid, task in self.tasks.items():
|
||||||
|
userid = int(userid)
|
||||||
|
if userid > 1000:
|
||||||
|
# Assume it is a twitch userid
|
||||||
|
profile = await UserProfile.fetch_from_twitchid(self.bot, userid)
|
||||||
|
|
||||||
|
if not profile:
|
||||||
|
# Create a new profile with this twitch user
|
||||||
|
users = await self.crocbot.fetch_users(ids=[userid])
|
||||||
|
if not users:
|
||||||
|
continue
|
||||||
|
user = users[0]
|
||||||
|
profile = await UserProfile.create_from_twitch(self.bot, user)
|
||||||
|
|
||||||
|
if not await self.data.Task.fetch(profile.profileid):
|
||||||
|
await task.update(userid=profile.profileid)
|
||||||
|
else:
|
||||||
|
await task.delete()
|
||||||
|
await self.load_tasks()
|
||||||
|
await self.channel.reload_tasklist()
|
||||||
|
|
||||||
async def cog_check(self, ctx):
|
async def cog_check(self, ctx):
|
||||||
if not self.loaded.is_set():
|
if not self.loaded.is_set():
|
||||||
await ctx.reply("Tasklists are still loading! Please wait a moment~")
|
await ctx.reply("Tasklists are still loading! Please wait a moment~")
|
||||||
@@ -123,25 +188,27 @@ class NowDoingCog(LionCog):
|
|||||||
# await self.channel.send_test_set()
|
# await self.channel.send_test_set()
|
||||||
# await ctx.send(f"Hello {ctx.author.name}! This command does something, we aren't sure what yet.")
|
# await ctx.send(f"Hello {ctx.author.name}! This command does something, we aren't sure what yet.")
|
||||||
# await ctx.send(str(list(self.tasks.items())[0]))
|
# await ctx.send(str(list(self.tasks.items())[0]))
|
||||||
|
await self.user_profile_migration()
|
||||||
await ctx.send(str(ctx.author.id))
|
await ctx.send(str(ctx.author.id))
|
||||||
|
await ctx.reply("Userid -> profile migration done.")
|
||||||
else:
|
else:
|
||||||
await ctx.send(f"Hello {ctx.author.name}! I don't think you have permission to test that.")
|
await ctx.send(f"Hello {ctx.author.name}! I don't think you have permission to test that.")
|
||||||
|
|
||||||
@commands.command(aliases=['task', 'check'])
|
async def now(self, ctx: commands.Context | LionContext, profile: UserProfile, args: Optional[str] = None, edit=False):
|
||||||
async def now(self, ctx: commands.Context, *, args: Optional[str] = None):
|
|
||||||
userid = int(ctx.author.id)
|
|
||||||
args = args.strip() if args else None
|
args = args.strip() if args else None
|
||||||
|
userid = profile.profileid
|
||||||
if args:
|
if args:
|
||||||
|
existing = self.tasks.get(userid, None)
|
||||||
await self.data.Task.table.delete_where(userid=userid)
|
await self.data.Task.table.delete_where(userid=userid)
|
||||||
task = await self.data.Task.create(
|
task = await self.data.Task.create(
|
||||||
userid=userid,
|
userid=userid,
|
||||||
name=ctx.author.display_name,
|
name=ctx.author.display_name,
|
||||||
task=args,
|
task=args,
|
||||||
started_at=utc_now(),
|
started_at=existing.started_at if (existing and edit) else utc_now(),
|
||||||
)
|
)
|
||||||
self.tasks[task.userid] = task
|
self.tasks[task.userid] = task
|
||||||
await self.channel.send_set(*self.channel.task_args(task))
|
await self.channel.send_set(*self.channel.task_args(task))
|
||||||
await ctx.send(f"Updated your current task, good luck!")
|
await ctx.send("Updated your current task, good luck!")
|
||||||
elif task := self.tasks.get(userid, None):
|
elif task := self.tasks.get(userid, None):
|
||||||
if task.is_done:
|
if task.is_done:
|
||||||
done_ago = strfdelta(utc_now() - task.done_at)
|
done_ago = strfdelta(utc_now() - task.done_at)
|
||||||
@@ -159,9 +226,38 @@ class NowDoingCog(LionCog):
|
|||||||
"Show what you are currently working on with, e.g. !now Reading notes"
|
"Show what you are currently working on with, e.g. !now Reading notes"
|
||||||
)
|
)
|
||||||
|
|
||||||
@commands.command(name='next')
|
@commands.command(
|
||||||
async def nownext(self, ctx: commands.Context, *, args: Optional[str] = None):
|
name='now',
|
||||||
userid = int(ctx.author.id)
|
aliases=['task', 'check']
|
||||||
|
)
|
||||||
|
async def twi_now(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||||
|
await self.now(ctx, profile, args)
|
||||||
|
|
||||||
|
@cmds.hybrid_command(
|
||||||
|
name='now',
|
||||||
|
aliases=['task', 'check']
|
||||||
|
)
|
||||||
|
async def disc_now(self, ctx: LionContext, *, args: Optional[str] = None):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||||
|
await self.now(ctx, profile, args)
|
||||||
|
|
||||||
|
@commands.command(
|
||||||
|
name='edit',
|
||||||
|
)
|
||||||
|
async def twi_edit(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||||
|
await self.now(ctx, profile, args, edit=True)
|
||||||
|
|
||||||
|
@cmds.hybrid_command(
|
||||||
|
name='edit',
|
||||||
|
)
|
||||||
|
async def disc_edit(self, ctx: LionContext, *, args: Optional[str] = None):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||||
|
await self.now(ctx, profile, args, edit=True)
|
||||||
|
|
||||||
|
async def nownext(self, ctx: commands.Context | LionContext, profile: UserProfile, args: Optional[str]):
|
||||||
|
userid = profile.profileid
|
||||||
task = self.tasks.get(userid, None)
|
task = self.tasks.get(userid, None)
|
||||||
if args:
|
if args:
|
||||||
if task:
|
if task:
|
||||||
@@ -182,7 +278,7 @@ class NowDoingCog(LionCog):
|
|||||||
)
|
)
|
||||||
self.tasks[task.userid] = task
|
self.tasks[task.userid] = task
|
||||||
await self.channel.send_set(*self.channel.task_args(task))
|
await self.channel.send_set(*self.channel.task_args(task))
|
||||||
await ctx.send(f"Next task set, good luck!" + ' ' + prefix)
|
await ctx.send("Next task set, good luck!" + ' ' + prefix)
|
||||||
elif task:
|
elif task:
|
||||||
if task.is_done:
|
if task.is_done:
|
||||||
done_ago = strfdelta(utc_now() - task.done_at)
|
done_ago = strfdelta(utc_now() - task.done_at)
|
||||||
@@ -200,9 +296,22 @@ class NowDoingCog(LionCog):
|
|||||||
"Show what you are currently working on with, e.g. !now Reading notes"
|
"Show what you are currently working on with, e.g. !now Reading notes"
|
||||||
)
|
)
|
||||||
|
|
||||||
@commands.command()
|
@commands.command(
|
||||||
async def done(self, ctx: commands.Context):
|
name='next',
|
||||||
userid = int(ctx.author.id)
|
)
|
||||||
|
async def twi_next(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||||
|
await self.nownext(ctx, profile, args)
|
||||||
|
|
||||||
|
@cmds.hybrid_command(
|
||||||
|
name='next',
|
||||||
|
)
|
||||||
|
async def disc_next(self, ctx: LionContext, *, args: Optional[str] = None):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||||
|
await self.nownext(ctx, profile, args)
|
||||||
|
|
||||||
|
async def done(self, ctx: commands.Context | LionContext, profile: UserProfile):
|
||||||
|
userid = profile.profileid
|
||||||
if task := self.tasks.get(userid, None):
|
if task := self.tasks.get(userid, None):
|
||||||
if task.is_done:
|
if task.is_done:
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
@@ -222,9 +331,36 @@ class NowDoingCog(LionCog):
|
|||||||
"Show what you are currently working on with, e.g. !now Reading notes"
|
"Show what you are currently working on with, e.g. !now Reading notes"
|
||||||
)
|
)
|
||||||
|
|
||||||
@commands.command()
|
@commands.command(
|
||||||
async def clear(self, ctx: commands.Context):
|
name='done',
|
||||||
userid = int(ctx.author.id)
|
)
|
||||||
|
async def twi_done(self, ctx: commands.Context):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||||
|
await self.done(ctx, profile)
|
||||||
|
|
||||||
|
@cmds.hybrid_command(
|
||||||
|
name='done',
|
||||||
|
)
|
||||||
|
async def disc_done(self, ctx: LionContext):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||||
|
await self.done(ctx, profile)
|
||||||
|
|
||||||
|
@commands.command(
|
||||||
|
name='clear',
|
||||||
|
)
|
||||||
|
async def twi_clear(self, ctx: commands.Context):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||||
|
await self.clear(ctx, profile)
|
||||||
|
|
||||||
|
@cmds.hybrid_command(
|
||||||
|
name='clear',
|
||||||
|
)
|
||||||
|
async def disc_clear(self, ctx: LionContext):
|
||||||
|
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||||
|
await self.clear(ctx, profile)
|
||||||
|
|
||||||
|
async def clear(self, ctx: commands.Context | LionContext, profile):
|
||||||
|
userid = profile.profileid
|
||||||
if task := self.tasks.pop(userid, None):
|
if task := self.tasks.pop(userid, None):
|
||||||
await task.delete()
|
await task.delete()
|
||||||
await self.channel.send_del(userid)
|
await self.channel.send_del(userid)
|
||||||
|
|||||||
@@ -47,16 +47,32 @@ class TimerChannel(Channel):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.cog = cog
|
self.cog = cog
|
||||||
|
|
||||||
|
self.channelid = 1261999440160624734
|
||||||
|
self.goal = 12
|
||||||
|
|
||||||
async def on_connection(self, websocket, event):
|
async def on_connection(self, websocket, event):
|
||||||
await super().on_connection(websocket, event)
|
await super().on_connection(websocket, event)
|
||||||
timer = self.cog.get_channel_timer(1261999440160624734)
|
await self.send_set(
|
||||||
if timer is not None:
|
**await self.get_args_for(self.channelid),
|
||||||
await self.send_set(
|
goal=self.goal,
|
||||||
timer.data.last_started,
|
websocket=websocket,
|
||||||
timer.data.focus_length,
|
)
|
||||||
timer.data.break_length,
|
|
||||||
websocket=websocket,
|
async def send_updates(self):
|
||||||
)
|
await self.send_set(
|
||||||
|
**await self.get_args_for(self.channelid),
|
||||||
|
goal=self.goal,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_args_for(self, channelid):
|
||||||
|
timer = self.cog.get_channel_timer(channelid)
|
||||||
|
if timer is None:
|
||||||
|
raise ValueError(f"Timer {channelid} doesn't exist.")
|
||||||
|
return {
|
||||||
|
'start_at': timer.data.last_started,
|
||||||
|
'focus_length': timer.data.focus_length,
|
||||||
|
'break_length': timer.data.break_length,
|
||||||
|
}
|
||||||
|
|
||||||
async def send_set(self, start_at, focus_length, break_length, goal=12, websocket=None):
|
async def send_set(self, start_at, focus_length, break_length, goal=12, websocket=None):
|
||||||
await self.send_event({
|
await self.send_event({
|
||||||
@@ -304,8 +320,6 @@ class TimerCog(LionCog):
|
|||||||
return
|
return
|
||||||
if member.bot:
|
if member.bot:
|
||||||
return
|
return
|
||||||
if 1148167212901859328 not in [role.id for role in member.roles]:
|
|
||||||
return
|
|
||||||
|
|
||||||
# If a member is leaving or joining a running timer, trigger a status update
|
# If a member is leaving or joining a running timer, trigger a status update
|
||||||
if before.channel != after.channel:
|
if before.channel != after.channel:
|
||||||
@@ -315,6 +329,7 @@ class TimerCog(LionCog):
|
|||||||
tasks = []
|
tasks = []
|
||||||
if leaving is not None:
|
if leaving is not None:
|
||||||
tasks.append(asyncio.create_task(leaving.update_status_card()))
|
tasks.append(asyncio.create_task(leaving.update_status_card()))
|
||||||
|
leaving.last_seen.pop(member.id, None)
|
||||||
if joining is not None:
|
if joining is not None:
|
||||||
joining.last_seen[member.id] = utc_now()
|
joining.last_seen[member.id] = utc_now()
|
||||||
if not joining.running and joining.auto_restart:
|
if not joining.running and joining.auto_restart:
|
||||||
@@ -1059,8 +1074,18 @@ class TimerCog(LionCog):
|
|||||||
@low_management_ward
|
@low_management_ward
|
||||||
async def streamtimer_update_cmd(self, ctx: LionContext,
|
async def streamtimer_update_cmd(self, ctx: LionContext,
|
||||||
new_start: Optional[str] = None,
|
new_start: Optional[str] = None,
|
||||||
new_goal: int = 12):
|
new_goal: Optional[int] = None,
|
||||||
timer = self.get_channel_timer(1261999440160624734)
|
new_channel: Optional[discord.VoiceChannel] = None,
|
||||||
|
):
|
||||||
|
if new_channel is not None:
|
||||||
|
channelid = self.channel.channelid = new_channel.id
|
||||||
|
else:
|
||||||
|
channelid = self.channel.channelid
|
||||||
|
|
||||||
|
if new_goal is not None:
|
||||||
|
self.channel.goal = new_goal
|
||||||
|
|
||||||
|
timer = self.get_channel_timer(channelid)
|
||||||
if timer is None:
|
if timer is None:
|
||||||
return
|
return
|
||||||
if new_start:
|
if new_start:
|
||||||
@@ -1068,10 +1093,5 @@ class TimerCog(LionCog):
|
|||||||
start_at = await self.bot.get_cog('Reminders').parse_time_static(new_start, timezone)
|
start_at = await self.bot.get_cog('Reminders').parse_time_static(new_start, timezone)
|
||||||
await timer.data.update(last_started=start_at)
|
await timer.data.update(last_started=start_at)
|
||||||
|
|
||||||
await self.channel.send_set(
|
await self.channel.send_updates()
|
||||||
timer.data.last_started,
|
|
||||||
timer.data.focus_length,
|
|
||||||
timer.data.break_length,
|
|
||||||
goal=new_goal,
|
|
||||||
)
|
|
||||||
await ctx.reply("Stream Timer Updated")
|
await ctx.reply("Stream Timer Updated")
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ from gui.cards import FocusTimerCard, BreakTimerCard
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .timer import Timer, Stage
|
from .timer import Timer, Stage
|
||||||
from tracking.voice.cog import VoiceTrackerCog
|
from tracking.voice.cog import VoiceTrackerCog
|
||||||
|
from modules.nowdoing.cog import NowDoingCog
|
||||||
|
|
||||||
|
|
||||||
async def get_timer_card(bot: LionBot, timer: 'Timer', stage: 'Stage'):
|
async def get_timer_card(bot: LionBot, timer: 'Timer', stage: 'Stage'):
|
||||||
voicecog: 'VoiceTrackerCog' = bot.get_cog('VoiceTrackerCog')
|
voicecog: 'VoiceTrackerCog' = bot.get_cog('VoiceTrackerCog')
|
||||||
|
nowcog: 'NowDoingCog' = bot.get_cog('NowDoingCog')
|
||||||
|
|
||||||
name = timer.base_name
|
name = timer.base_name
|
||||||
if stage is not None:
|
if stage is not None:
|
||||||
@@ -23,16 +25,22 @@ async def get_timer_card(bot: LionBot, timer: 'Timer', stage: 'Stage'):
|
|||||||
card_users = []
|
card_users = []
|
||||||
guildid = timer.data.guildid
|
guildid = timer.data.guildid
|
||||||
for member in timer.members:
|
for member in timer.members:
|
||||||
if voicecog is not None:
|
profile = await bot.get_cog('ProfileCog').fetch_profile_discord(member)
|
||||||
session = voicecog.get_session(guildid, member.id)
|
task = nowcog.tasks.get(profile.profileid, None)
|
||||||
tag = session.tag
|
tag = ''
|
||||||
if session.start_time:
|
session_duration = 0
|
||||||
session_duration = (utc_now() - session.start_time).total_seconds()
|
|
||||||
else:
|
if task:
|
||||||
session_duration = 0
|
tag = task.task
|
||||||
|
session_duration = ((task.done_at or utc_now()) - task.started_at).total_seconds()
|
||||||
else:
|
else:
|
||||||
session_duration = 0
|
session = voicecog.get_session(guildid, member.id)
|
||||||
tag = None
|
if session:
|
||||||
|
tag = session.tag
|
||||||
|
if session.start_time:
|
||||||
|
session_duration = (utc_now() - session.start_time).total_seconds()
|
||||||
|
else:
|
||||||
|
session_duration = 0
|
||||||
|
|
||||||
card_user = (
|
card_user = (
|
||||||
(member.id, (member.avatar or member.default_avatar).key),
|
(member.id, (member.avatar or member.default_avatar).key),
|
||||||
|
|||||||
8
src/modules/profiles/__init__.py
Normal file
8
src/modules/profiles/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from .cog import ProfileCog
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
await bot.add_cog(ProfileCog(bot))
|
||||||
415
src/modules/profiles/cog.py
Normal file
415
src/modules/profiles/cog.py
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
import asyncio
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, overload
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord import app_commands as appcmds
|
||||||
|
from discord.ext import commands as cmds
|
||||||
|
from twitchAPI.type import AuthScope
|
||||||
|
import twitchio
|
||||||
|
from twitchio.ext import commands
|
||||||
|
from twitchio import User
|
||||||
|
from twitchAPI.object.api import TwitchUser
|
||||||
|
|
||||||
|
|
||||||
|
from data.queries import ORDER
|
||||||
|
from meta import LionCog, LionBot, CrocBot, LionContext
|
||||||
|
from meta.logger import log_wrap
|
||||||
|
from utils.lib import utc_now
|
||||||
|
from . import logger
|
||||||
|
from .data import ProfileData
|
||||||
|
from .profile import UserProfile
|
||||||
|
from .community import Community
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileCog(LionCog):
|
||||||
|
def __init__(self, bot: LionBot):
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
assert bot.crocbot is not None
|
||||||
|
self.crocbot: CrocBot = bot.crocbot
|
||||||
|
self.data = bot.db.load_registry(ProfileData())
|
||||||
|
|
||||||
|
self._profile_migrators = {}
|
||||||
|
self._comm_migrators = {}
|
||||||
|
|
||||||
|
async def cog_load(self):
|
||||||
|
await self.data.init()
|
||||||
|
|
||||||
|
async def cog_check(self, ctx):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def bot_check_once(self, ctx: LionContext):
|
||||||
|
"""
|
||||||
|
Inject the contextual UserProfile and Community into the LionContext.
|
||||||
|
|
||||||
|
Creates the profile and community if they do not exist.
|
||||||
|
"""
|
||||||
|
if ctx.guild:
|
||||||
|
ctx.community = await self.fetch_community_discord(ctx.guild)
|
||||||
|
ctx.profile = await self.fetch_profile_discord(ctx.author)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Profile API
|
||||||
|
def add_profile_migrator(self, migrator, name=None):
|
||||||
|
name = name or migrator.__name__
|
||||||
|
self._profile_migrators[name or migrator.__name__] = migrator
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Added user profile migrator {name}: {migrator}"
|
||||||
|
)
|
||||||
|
return migrator
|
||||||
|
|
||||||
|
def del_profile_migrator(self, name: str):
|
||||||
|
migrator = self._profile_migrators.pop(name, None)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Removed user profile migrator {name}: {migrator}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_wrap(action="profile migration")
|
||||||
|
async def migrate_profile(self, source_profile, target_profile) -> list[str]:
|
||||||
|
logger.info(
|
||||||
|
f"Beginning user profile migration from {source_profile!r} to {target_profile!r}"
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
# Wrap this in a transaction so if something goes wrong with migration,
|
||||||
|
# we roll back safely (although this may mess up caches)
|
||||||
|
async with self.bot.db.connection() as conn:
|
||||||
|
self.bot.db.conn = conn
|
||||||
|
async with conn.transaction():
|
||||||
|
for name, migrator in self._profile_migrators.items():
|
||||||
|
try:
|
||||||
|
result = await migrator(source_profile, target_profile)
|
||||||
|
if result:
|
||||||
|
results.append(result)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"Unexpected exception running user profile migrator {name} "
|
||||||
|
f"migrating {source_profile!r} to {target_profile!r}."
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Move all Discord and Twitch profile references over to the new profile
|
||||||
|
discord_rows = await self.data.DiscordProfileRow.table.update_where(
|
||||||
|
profileid=source_profile.profileid
|
||||||
|
).set(profileid=target_profile.profileid)
|
||||||
|
results.append(f"Migrated {len(discord_rows)} attached discord account(s).")
|
||||||
|
|
||||||
|
twitch_rows = await self.data.TwitchProfileRow.table.update_where(
|
||||||
|
profileid=source_profile.profileid
|
||||||
|
).set(profileid=target_profile.profileid)
|
||||||
|
results.append(f"Migrated {len(twitch_rows)} attached twitch account(s).")
|
||||||
|
|
||||||
|
# And then mark the old profile as migrated
|
||||||
|
await source_profile.profile_row.update(migrated=target_profile.profileid)
|
||||||
|
results.append("Marking old profile as migrated.. finished!")
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def fetch_profile_by_id(self, profile_id: int) -> UserProfile:
|
||||||
|
"""
|
||||||
|
Fetch a UserProfile by the given id.
|
||||||
|
"""
|
||||||
|
return await UserProfile.fetch(self.bot, profile_id=profile_id)
|
||||||
|
|
||||||
|
async def fetch_profile_discord(self, user: discord.Member | discord.User) -> UserProfile:
|
||||||
|
"""
|
||||||
|
Fetch or create a UserProfile from the provided discord account.
|
||||||
|
"""
|
||||||
|
profile = await UserProfile.fetch_from_discordid(self.bot, user.id)
|
||||||
|
if profile is None:
|
||||||
|
profile = await UserProfile.create_from_discord(self.bot, user)
|
||||||
|
return profile
|
||||||
|
|
||||||
|
async def fetch_profile_twitch(self, user: twitchio.User) -> UserProfile:
|
||||||
|
"""
|
||||||
|
Fetch or create a UserProfile from the provided twitch account.
|
||||||
|
"""
|
||||||
|
profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
|
||||||
|
if profile is None:
|
||||||
|
profile = await UserProfile.create_from_twitch(self.bot, user)
|
||||||
|
return profile
|
||||||
|
|
||||||
|
# Community API
|
||||||
|
def add_community_migrator(self, migrator, name=None):
|
||||||
|
name = name or migrator.__name__
|
||||||
|
self._comm_migrators[name or migrator.__name__] = migrator
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Added community migrator {name}: {migrator}"
|
||||||
|
)
|
||||||
|
return migrator
|
||||||
|
|
||||||
|
def del_community_migrator(self, name: str):
|
||||||
|
migrator = self._comm_migrators.pop(name, None)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Removed community migrator {name}: {migrator}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@log_wrap(action="community migration")
|
||||||
|
async def migrate_community(self, source_comm, target_comm) -> list[str]:
|
||||||
|
logger.info(
|
||||||
|
f"Beginning community migration from {source_comm!r} to {target_comm!r}"
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
# Wrap this in a transaction so if something goes wrong with migration,
|
||||||
|
# we roll back safely (although this may mess up caches)
|
||||||
|
async with self.bot.db.connection() as conn:
|
||||||
|
self.bot.db.conn = conn
|
||||||
|
async with conn.transaction():
|
||||||
|
for name, migrator in self._comm_migrators.items():
|
||||||
|
try:
|
||||||
|
result = await migrator(source_comm, target_comm)
|
||||||
|
if result:
|
||||||
|
results.append(result)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"Unexpected exception running community migrator {name} "
|
||||||
|
f"migrating {source_comm!r} to {target_comm!r}."
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Move all Discord and Twitch community preferences over to the new profile
|
||||||
|
discord_rows = await self.data.DiscordCommunityRow.table.update_where(
|
||||||
|
profileid=source_comm.communityid
|
||||||
|
).set(communityid=target_comm.communityid)
|
||||||
|
results.append(f"Migrated {len(discord_rows)} attached discord guilds.")
|
||||||
|
|
||||||
|
twitch_rows = await self.data.TwitchCommunityRow.table.update_where(
|
||||||
|
communityid=source_comm.communityid
|
||||||
|
).set(communityid=target_comm.communityid)
|
||||||
|
results.append(f"Migrated {len(twitch_rows)} attached twitch channel(s).")
|
||||||
|
|
||||||
|
# And then mark the old community as migrated
|
||||||
|
await source_comm.update(migrated=target_comm.communityid)
|
||||||
|
results.append("Marking old community as migrated.. finished!")
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def fetch_community_by_id(self, community_id: int) -> Community:
|
||||||
|
"""
|
||||||
|
Fetch a Community by the given id.
|
||||||
|
"""
|
||||||
|
return await Community.fetch(self.bot, community_id=community_id)
|
||||||
|
|
||||||
|
async def fetch_community_discord(self, guild: discord.Guild) -> Community:
|
||||||
|
"""
|
||||||
|
Fetch or create a Community from the provided discord guild.
|
||||||
|
"""
|
||||||
|
comm = await Community.fetch_from_discordid(self.bot, guild.id)
|
||||||
|
if comm is None:
|
||||||
|
comm = await Community.create_from_discord(self.bot, guild)
|
||||||
|
return comm
|
||||||
|
|
||||||
|
async def fetch_community_twitch(self, user: twitchio.User) -> Community:
|
||||||
|
"""
|
||||||
|
Fetch or create a Community from the provided twitch account.
|
||||||
|
"""
|
||||||
|
community = await Community.fetch_from_twitchid(self.bot, user.id)
|
||||||
|
if community is None:
|
||||||
|
community = await Community.create_from_twitch(self.bot, user)
|
||||||
|
return community
|
||||||
|
|
||||||
|
# ----- Profile Commands -----
|
||||||
|
@cmds.hybrid_group(
|
||||||
|
name='profiles',
|
||||||
|
description="Base comand group for user profiles."
|
||||||
|
)
|
||||||
|
async def profiles_grp(self, ctx: LionContext):
|
||||||
|
...
|
||||||
|
|
||||||
|
@profiles_grp.group(
|
||||||
|
name='link',
|
||||||
|
description="Base command group for linking profiles"
|
||||||
|
)
|
||||||
|
async def profiles_link_grp(self, ctx: LionContext):
|
||||||
|
...
|
||||||
|
|
||||||
|
@profiles_link_grp.command(
|
||||||
|
name='twitch',
|
||||||
|
description="Link a twitch account to your current profile."
|
||||||
|
)
|
||||||
|
async def profiles_link_twitch_cmd(self, ctx: LionContext):
|
||||||
|
if not ctx.interaction:
|
||||||
|
return
|
||||||
|
|
||||||
|
await ctx.interaction.response.defer(ephemeral=True)
|
||||||
|
|
||||||
|
# Ask the user to go through auth to get their userid
|
||||||
|
auth_cog = self.bot.get_cog('TwitchAuthCog')
|
||||||
|
flow = await auth_cog.start_auth()
|
||||||
|
message = await ctx.reply(
|
||||||
|
f"Please [click here]({flow.auth.return_auth_url()}) to link your profile "
|
||||||
|
"to Twitch."
|
||||||
|
)
|
||||||
|
authrow = await flow.run()
|
||||||
|
await message.edit(
|
||||||
|
content="Authentication Complete! Beginning profile merge..."
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await self.crocbot.fetch_users(ids=[authrow.userid])
|
||||||
|
if not results:
|
||||||
|
logger.error(
|
||||||
|
f"User {authrow} obtained from Twitch authentication does not exist."
|
||||||
|
)
|
||||||
|
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||||
|
return
|
||||||
|
|
||||||
|
user = results[0]
|
||||||
|
|
||||||
|
# Retrieve author's profile if it exists
|
||||||
|
author_profile = await UserProfile.fetch_from_discordid(self.bot, ctx.author.id)
|
||||||
|
|
||||||
|
# Check if the twitch-side user has a profile
|
||||||
|
source_profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
|
||||||
|
|
||||||
|
if author_profile and source_profile is None:
|
||||||
|
# All we need to do is attach the twitch row
|
||||||
|
await author_profile.attach_twitch(user)
|
||||||
|
await message.edit(
|
||||||
|
content=f"Successfully added Twitch account **{user.name}**! There was no profile data to merge."
|
||||||
|
)
|
||||||
|
elif source_profile and author_profile is None:
|
||||||
|
# Attach the discord row to the profile
|
||||||
|
await source_profile.attach_discord(ctx.author)
|
||||||
|
await message.edit(
|
||||||
|
content=f"Successfully connected to Twitch profile **{user.name}**! There was no profile data to merge."
|
||||||
|
)
|
||||||
|
elif source_profile is None and author_profile is None:
|
||||||
|
profile = await UserProfile.create_from_discord(self.bot, ctx.author)
|
||||||
|
await profile.attach_twitch(user)
|
||||||
|
|
||||||
|
await message.edit(
|
||||||
|
content=f"Opened a new user profile for you and linked Twitch account **{user.name}**."
|
||||||
|
)
|
||||||
|
elif author_profile.profileid == source_profile.profileid:
|
||||||
|
await message.edit(
|
||||||
|
content=f"The Twitch account **{user.name}** is already linked to your profile!"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Migrate the existing profile data to the new profiles
|
||||||
|
try:
|
||||||
|
results = await self.migrate_profile(source_profile, author_profile)
|
||||||
|
except Exception:
|
||||||
|
await ctx.error_reply(
|
||||||
|
"An issue was encountered while merging your account profiles!\n"
|
||||||
|
"Migration rolled back, no data has been lost.\n"
|
||||||
|
"The developer has been notified. Please try again later!"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
content = '\n'.join((
|
||||||
|
"## Connecting Twitch account and merging profiles...",
|
||||||
|
*results,
|
||||||
|
"**Successfully linked account and merged profile data!**"
|
||||||
|
))
|
||||||
|
await message.edit(content=content)
|
||||||
|
|
||||||
|
# ----- Community Commands -----
|
||||||
|
@cmds.hybrid_group(
|
||||||
|
name='community',
|
||||||
|
description="Base comand group for community profiles."
|
||||||
|
)
|
||||||
|
async def community_grp(self, ctx: LionContext):
|
||||||
|
...
|
||||||
|
|
||||||
|
@community_grp.group(
|
||||||
|
name='link',
|
||||||
|
description="Base command group for linking communities"
|
||||||
|
)
|
||||||
|
async def community_link_grp(self, ctx: LionContext):
|
||||||
|
...
|
||||||
|
|
||||||
|
@community_link_grp.command(
|
||||||
|
name='twitch',
|
||||||
|
description="Link a twitch account to this community."
|
||||||
|
)
|
||||||
|
@appcmds.guild_only()
|
||||||
|
@appcmds.default_permissions(manage_guild=True)
|
||||||
|
async def comm_link_twitch_cmd(self, ctx: LionContext):
|
||||||
|
if not ctx.interaction:
|
||||||
|
return
|
||||||
|
assert ctx.guild is not None
|
||||||
|
|
||||||
|
await ctx.interaction.response.defer(ephemeral=True)
|
||||||
|
|
||||||
|
if not ctx.author.guild_permissions.manage_guild:
|
||||||
|
await ctx.error_reply("You need the `MANAGE_GUILD` permission to link this guild to a community.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ask the user to go through auth to get their userid
|
||||||
|
auth_cog = self.bot.get_cog('TwitchAuthCog')
|
||||||
|
flow = await auth_cog.start_auth(
|
||||||
|
scopes=[
|
||||||
|
AuthScope.CHAT_EDIT,
|
||||||
|
AuthScope.CHAT_READ,
|
||||||
|
AuthScope.MODERATION_READ,
|
||||||
|
AuthScope.CHANNEL_BOT,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
message = await ctx.reply(
|
||||||
|
f"Please [click here]({flow.auth.return_auth_url()}) to link your Twitch channel to this server."
|
||||||
|
)
|
||||||
|
authrow = await flow.run()
|
||||||
|
await message.edit(
|
||||||
|
content="Authentication Complete! Beginning community profile merge..."
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await self.crocbot.fetch_users(ids=[authrow.userid])
|
||||||
|
if not results:
|
||||||
|
logger.error(
|
||||||
|
f"User {authrow} obtained from Twitch authentication does not exist."
|
||||||
|
)
|
||||||
|
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||||
|
return
|
||||||
|
|
||||||
|
user = results[0]
|
||||||
|
|
||||||
|
# Retrieve author's profile if it exists
|
||||||
|
guild_comm = await Community.fetch_from_discordid(self.bot, ctx.guild.id)
|
||||||
|
|
||||||
|
# Check if the twitch-side user has a profile
|
||||||
|
twitch_comm = await Community.fetch_from_twitchid(self.bot, user.id)
|
||||||
|
|
||||||
|
if guild_comm and twitch_comm is None:
|
||||||
|
# All we need to do is attach the twitch row
|
||||||
|
await guild_comm.attach_twitch(user)
|
||||||
|
await message.edit(
|
||||||
|
content=f"Successfully linked Twitch channel **{user.name}**! There was no community data to merge."
|
||||||
|
)
|
||||||
|
elif twitch_comm and guild_comm is None:
|
||||||
|
# Attach the discord row to the profile
|
||||||
|
await twitch_comm.attach_discord(ctx.guild)
|
||||||
|
await message.edit(
|
||||||
|
content=f"Successfully connected to Twitch channel **{user.name}**!"
|
||||||
|
)
|
||||||
|
elif twitch_comm is None and guild_comm is None:
|
||||||
|
profile = await Community.create_from_discord(self.bot, ctx.guild)
|
||||||
|
await profile.attach_twitch(user)
|
||||||
|
|
||||||
|
await message.edit(
|
||||||
|
content=f"Created a new community for this server and linked Twitch account **{user.name}**."
|
||||||
|
)
|
||||||
|
elif guild_comm.communityid == twitch_comm.communityid:
|
||||||
|
await message.edit(
|
||||||
|
content=f"This server is already linked to the Twitch channel **{user.name}**!"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Migrate the existing profile data to the new profiles
|
||||||
|
try:
|
||||||
|
results = await self.migrate_community(twitch_comm, guild_comm)
|
||||||
|
except Exception:
|
||||||
|
await ctx.error_reply(
|
||||||
|
"An issue was encountered while merging your community profiles!\n"
|
||||||
|
"Migration rolled back, no data has been lost.\n"
|
||||||
|
"The developer has been notified. Please try again later!"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
content = '\n'.join((
|
||||||
|
"## Connecting Twitch account and merging community profiles...",
|
||||||
|
*results,
|
||||||
|
"**Successfully linked account and merged community data!**"
|
||||||
|
))
|
||||||
|
await message.edit(content=content)
|
||||||
123
src/modules/profiles/community.py
Normal file
123
src/modules/profiles/community.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
import discord
|
||||||
|
import twitchio
|
||||||
|
|
||||||
|
from meta import LionBot
|
||||||
|
from utils.lib import utc_now
|
||||||
|
|
||||||
|
from . import logger
|
||||||
|
from .data import ProfileData
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Community:
|
||||||
|
def __init__(self, bot: LionBot, community_row):
|
||||||
|
self.bot = bot
|
||||||
|
self.row: ProfileData.CommunityRow = community_row
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cog(self):
|
||||||
|
return self.bot.get_cog('ProfileCog')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> ProfileData:
|
||||||
|
return self.cog.data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def communityid(self):
|
||||||
|
return self.row.communityid
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<Community communityid={self.communityid} row={self.row}>"
|
||||||
|
|
||||||
|
async def attach_discord(self, guild: discord.Guild):
|
||||||
|
"""
|
||||||
|
Attach a new discord guild to this community.
|
||||||
|
Assumes the discord guild is not already associated to a community.
|
||||||
|
"""
|
||||||
|
discord_row = await self.data.DiscordCommunityRow.create(
|
||||||
|
communityid=self.communityid,
|
||||||
|
guildid=guild.id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Attached discord guild {guild!r} to community {self!r}"
|
||||||
|
)
|
||||||
|
return discord_row
|
||||||
|
|
||||||
|
async def attach_twitch(self, user: twitchio.User):
|
||||||
|
"""
|
||||||
|
Attach a new Twitch user channel to this community.
|
||||||
|
"""
|
||||||
|
twitch_row = await self.data.TwitchCommunityRow.create(
|
||||||
|
communityid=self.communityid,
|
||||||
|
channelid=str(user.id)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Attached twitch channel {user!r} to community {self!r}"
|
||||||
|
)
|
||||||
|
return twitch_row
|
||||||
|
|
||||||
|
async def discord_guilds(self) -> list[ProfileData.DiscordCommunityRow]:
|
||||||
|
"""
|
||||||
|
Fetch the Discord guild rows associated to this community.
|
||||||
|
"""
|
||||||
|
return await self.data.DiscordCommunityRow.fetch_where(communityid=self.communityid)
|
||||||
|
|
||||||
|
async def twitch_channels(self) -> list[ProfileData.TwitchCommunityRow]:
|
||||||
|
"""
|
||||||
|
Fetch the Twitch user rows associated to this profile.
|
||||||
|
"""
|
||||||
|
return await self.data.TwitchCommunityRow.fetch_where(communityid=self.communityid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch(cls, bot: LionBot, community_id: int) -> Self:
|
||||||
|
community_row = await bot.get_cog('ProfileCog').data.CommunityRow.fetch(community_id)
|
||||||
|
if community_row is None:
|
||||||
|
raise ValueError("Provided community_id does not exist.")
|
||||||
|
return cls(bot, community_row)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_twitchid(cls, bot: LionBot, channelid: int | str) -> Optional[Self]:
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
rows = await data.TwitchCommunityRow.fetch_where(channelid=str(channelid))
|
||||||
|
if rows:
|
||||||
|
return await cls.fetch(bot, rows[0].communityid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_discordid(cls, bot: LionBot, guildid: int) -> Optional[Self]:
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
rows = await data.DiscordCommunityRow.fetch_where(guildid=guildid)
|
||||||
|
if rows:
|
||||||
|
return await cls.fetch(bot, rows[0].communityid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, bot: LionBot, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new empty community with the given initial arguments.
|
||||||
|
|
||||||
|
Communities should usually be created using `create_from_discord` or `create_from_twitch`
|
||||||
|
to correctly setup initial preferences (e.g. name, avatar).
|
||||||
|
"""
|
||||||
|
# Create a new community
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
row = await data.CommunityRow.create(created_at=utc_now(), **kwargs)
|
||||||
|
return await cls.fetch(bot, row.communityid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_from_discord(cls, bot: LionBot, guild: discord.Guild, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new community using the given Discord guild as a base.
|
||||||
|
"""
|
||||||
|
self = await cls.create(bot, **kwargs)
|
||||||
|
await self.attach_discord(guild)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_from_twitch(cls, bot: LionBot, user: twitchio.User, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new profile using the given Twitch channel user as a base.
|
||||||
|
"""
|
||||||
|
self = await cls.create(bot, **kwargs)
|
||||||
|
await self.attach_twitch(user)
|
||||||
|
return self
|
||||||
158
src/modules/profiles/data.py
Normal file
158
src/modules/profiles/data.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
from data import Registry, RowModel
|
||||||
|
from data.columns import Integer, String, Timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileData(Registry):
|
||||||
|
class UserProfileRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE user_profiles(
|
||||||
|
profileid SERIAL PRIMARY KEY,
|
||||||
|
nickname TEXT,
|
||||||
|
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'user_profiles'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
profileid = Integer(primary=True)
|
||||||
|
nickname = String()
|
||||||
|
migrated = Integer()
|
||||||
|
created_at = Timestamp()
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordProfileRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE profiles_discord(
|
||||||
|
linkid SERIAL PRIMARY KEY,
|
||||||
|
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
userid BIGINT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid);
|
||||||
|
CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'profiles_discord'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
linkid = Integer(primary=True)
|
||||||
|
profileid = Integer()
|
||||||
|
userid = Integer()
|
||||||
|
created_at = Integer()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_profile(cls, profileid: int):
|
||||||
|
rows = await cls.fetch_where(profiled=profileid)
|
||||||
|
return next(rows, None)
|
||||||
|
|
||||||
|
|
||||||
|
class TwitchProfileRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE profiles_twitch(
|
||||||
|
linkid SERIAL PRIMARY KEY,
|
||||||
|
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
userid TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid);
|
||||||
|
CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'profiles_twitch'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
linkid = Integer(primary=True)
|
||||||
|
profileid = Integer()
|
||||||
|
userid = String()
|
||||||
|
created_at = Timestamp()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_profile(cls, profileid: int):
|
||||||
|
rows = await cls.fetch_where(profiled=profileid)
|
||||||
|
return next(rows, None)
|
||||||
|
|
||||||
|
class CommunityRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE communities(
|
||||||
|
communityid SERIAL PRIMARY KEY,
|
||||||
|
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'communities'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
communityid = Integer(primary=True)
|
||||||
|
migrated = Integer()
|
||||||
|
created_at = Timestamp()
|
||||||
|
|
||||||
|
class DiscordCommunityRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE communities_discord(
|
||||||
|
guildid BIGINT PRIMARY KEY,
|
||||||
|
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'communities_discord'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
guildid = Integer(primary=True)
|
||||||
|
communityid = Integer()
|
||||||
|
linked_at = Timestamp()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_community(cls, communityid: int):
|
||||||
|
rows = await cls.fetch_where(communityd=communityid)
|
||||||
|
return next(rows, None)
|
||||||
|
|
||||||
|
class TwitchCommunityRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE communities_twitch(
|
||||||
|
channelid TEXT PRIMARY KEY,
|
||||||
|
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'communities_twitch'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
channelid = String(primary=True)
|
||||||
|
communityid = Integer()
|
||||||
|
linked_at = Timestamp()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_community(cls, communityid: int):
|
||||||
|
rows = await cls.fetch_where(communityd=communityid)
|
||||||
|
return next(rows, None)
|
||||||
|
|
||||||
|
class CommunityMemberRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE community_members(
|
||||||
|
memberid SERIAL PRIMARY KEY,
|
||||||
|
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'community_members'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
memberid = Integer(primary=True)
|
||||||
|
communityid = Integer()
|
||||||
|
profileid = Integer()
|
||||||
|
created_at = Timestamp()
|
||||||
124
src/modules/profiles/profile.py
Normal file
124
src/modules/profiles/profile.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
import discord
|
||||||
|
import twitchio
|
||||||
|
|
||||||
|
from meta import LionBot
|
||||||
|
from utils.lib import utc_now
|
||||||
|
|
||||||
|
from . import logger
|
||||||
|
from .data import ProfileData
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class UserProfile:
|
||||||
|
def __init__(self, bot: LionBot, profile_row):
|
||||||
|
self.bot = bot
|
||||||
|
self.profile_row: ProfileData.UserProfileRow = profile_row
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cog(self):
|
||||||
|
return self.bot.get_cog('ProfileCog')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> ProfileData:
|
||||||
|
return self.cog.data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def profileid(self):
|
||||||
|
return self.profile_row.profileid
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<UserProfile profileid={self.profileid} profile={self.profile_row}>"
|
||||||
|
|
||||||
|
async def attach_discord(self, user: discord.User | discord.Member):
|
||||||
|
"""
|
||||||
|
Attach a new discord user to this profile.
|
||||||
|
Assumes the discord user does not itself have a profile.
|
||||||
|
"""
|
||||||
|
discord_row = await self.data.DiscordProfileRow.create(
|
||||||
|
profileid=self.profileid,
|
||||||
|
userid=user.id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Attached discord user {user!r} to profile {self!r}"
|
||||||
|
)
|
||||||
|
return discord_row
|
||||||
|
|
||||||
|
async def attach_twitch(self, user: twitchio.User):
|
||||||
|
"""
|
||||||
|
Attach a new Twitch user to this profile.
|
||||||
|
"""
|
||||||
|
twitch_row = await self.data.TwitchProfileRow.create(
|
||||||
|
profileid=self.profileid,
|
||||||
|
userid=str(user.id)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Attached twitch user {user!r} to profile {self!r}"
|
||||||
|
)
|
||||||
|
return twitch_row
|
||||||
|
|
||||||
|
async def discord_accounts(self) -> list[ProfileData.DiscordProfileRow]:
|
||||||
|
"""
|
||||||
|
Fetch the Discord accounts associated to this profile.
|
||||||
|
"""
|
||||||
|
return await self.data.DiscordProfileRow.fetch_where(profileid=self.profileid)
|
||||||
|
|
||||||
|
async def twitch_accounts(self) -> list[ProfileData.TwitchProfileRow]:
|
||||||
|
"""
|
||||||
|
Fetch the Twitch accounts associated to this profile.
|
||||||
|
"""
|
||||||
|
return await self.data.TwitchProfileRow.fetch_where(profileid=self.profileid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch(cls, bot: LionBot, profile_id: int) -> Self:
|
||||||
|
profile_row = await bot.get_cog('ProfileCog').data.UserProfileRow.fetch(profile_id)
|
||||||
|
if profile_row is None:
|
||||||
|
raise ValueError("Provided profile_id does not exist.")
|
||||||
|
return cls(bot, profile_row)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_twitchid(cls, bot: LionBot, userid: int | str) -> Optional[Self]:
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
rows = await data.TwitchProfileRow.fetch_where(userid=str(userid))
|
||||||
|
if rows:
|
||||||
|
return await cls.fetch(bot, rows[0].profileid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_from_discordid(cls, bot: LionBot, userid: int) -> Optional[Self]:
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
rows = await data.DiscordProfileRow.fetch_where(userid=(userid))
|
||||||
|
if rows:
|
||||||
|
return await cls.fetch(bot, rows[0].profileid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(cls, bot: LionBot, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new empty profile with the given initial arguments.
|
||||||
|
|
||||||
|
Profiles should usually be created using `create_from_discord` or `create_from_twitch`
|
||||||
|
to correctly setup initial profile preferences (e.g. name, avatar).
|
||||||
|
"""
|
||||||
|
# Create a new profile
|
||||||
|
data = bot.get_cog('ProfileCog').data
|
||||||
|
profile_row = await data.UserProfileRow.create(created_at=utc_now())
|
||||||
|
profile = await cls.fetch(bot, profile_row.profileid)
|
||||||
|
return profile
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_from_discord(cls, bot: LionBot, user: discord.Member | discord.User, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new profile using the given Discord user as a base.
|
||||||
|
"""
|
||||||
|
profile = await cls.create(bot, **kwargs)
|
||||||
|
await profile.attach_discord(user)
|
||||||
|
return profile
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_from_twitch(cls, bot: LionBot, user: twitchio.User, **kwargs) -> Self:
|
||||||
|
"""
|
||||||
|
Create a new profile using the given Twitch user as a base.
|
||||||
|
"""
|
||||||
|
profile = await cls.create(bot, **kwargs)
|
||||||
|
await profile.attach_twitch(user)
|
||||||
|
return profile
|
||||||
@@ -17,9 +17,19 @@ class ShoutoutCog(LionCog):
|
|||||||
and drop a follow! \
|
and drop a follow! \
|
||||||
They {areorwere} streaming {game} at {channel}
|
They {areorwere} streaming {game} at {channel}
|
||||||
"""
|
"""
|
||||||
|
COWO_SHOUTOUT = """
|
||||||
|
We think that {name} is a great coworker and you should check them out for more productive vibes! \
|
||||||
|
They {areorwere} streaming {game} at {channel}
|
||||||
|
"""
|
||||||
|
ART_SHOUTOUT = """
|
||||||
|
We think that {name} is an awesome artist and you should check them out for cool art and cosy vibes! \
|
||||||
|
They {areorwere} streaming {game} at {channel}
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, bot: LionBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.crocbot = bot.crocbot
|
self.crocbot: CrocBot = bot.crocbot
|
||||||
|
|
||||||
self.data = bot.db.load_registry(ShoutoutData())
|
self.data = bot.db.load_registry(ShoutoutData())
|
||||||
|
|
||||||
self.loaded = asyncio.Event()
|
self.loaded = asyncio.Event()
|
||||||
@@ -59,19 +69,28 @@ class ShoutoutCog(LionCog):
|
|||||||
return replace_multiple(text, mapping)
|
return replace_multiple(text, mapping)
|
||||||
|
|
||||||
@commands.command(aliases=['so'])
|
@commands.command(aliases=['so'])
|
||||||
async def shoutout(self, ctx: commands.Context, user: twitchio.User):
|
async def shoutout(self, ctx: commands.Context, target: str, typ: Optional[str]=None):
|
||||||
# Make sure caller is mod/broadcaster
|
# Make sure caller is mod/broadcaster
|
||||||
# Lookup custom shoutout for this user
|
# Lookup custom shoutout for this user
|
||||||
# If it exists use it, otherwise use default shoutout
|
# If it exists use it, otherwise use default shoutout
|
||||||
if (ctx.author.is_mod or ctx.author.is_broadcaster):
|
if (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||||
data = await self.data.CustomShoutout.fetch(int(user.id))
|
user = await self.crocbot.seek_user(target)
|
||||||
if data:
|
if user is None:
|
||||||
shoutout = data.content
|
await ctx.reply(f"Couldn't resolve '{target}' to a valid user.")
|
||||||
else:
|
else:
|
||||||
shoutout = self.DEFAULT_SHOUTOUT
|
data = await self.data.CustomShoutout.fetch(int(user.id))
|
||||||
formatted = await self.format_shoutout(shoutout, user)
|
if data:
|
||||||
await ctx.reply(formatted)
|
shoutout = data.content
|
||||||
|
elif typ == 'cowo':
|
||||||
|
shoutout = self.COWO_SHOUTOUT
|
||||||
|
elif typ == 'art':
|
||||||
|
shoutout = self.ART_SHOUTOUT
|
||||||
|
else:
|
||||||
|
shoutout = self.DEFAULT_SHOUTOUT
|
||||||
|
formatted = await self.format_shoutout(shoutout, user)
|
||||||
|
await ctx.reply(formatted)
|
||||||
# TODO: How to /shoutout with lib?
|
# TODO: How to /shoutout with lib?
|
||||||
|
# TODO Shoutout queue
|
||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
async def editshoutout(self, ctx: commands.Context, user: twitchio.User, *, text: str):
|
async def editshoutout(self, ctx: commands.Context, user: twitchio.User, *, text: str):
|
||||||
|
|||||||
@@ -7,9 +7,12 @@ from discord.ext import commands as cmds
|
|||||||
from discord import app_commands as appcmds
|
from discord import app_commands as appcmds
|
||||||
from discord.app_commands.transformers import AppCommandOptionType as cmdopt
|
from discord.app_commands.transformers import AppCommandOptionType as cmdopt
|
||||||
|
|
||||||
|
from data.queries import JOINTYPE
|
||||||
from meta import LionBot, LionCog, LionContext
|
from meta import LionBot, LionCog, LionContext
|
||||||
|
from meta.CrocBot import CrocBot
|
||||||
from meta.logger import log_wrap
|
from meta.logger import log_wrap
|
||||||
from meta.errors import UserInputError
|
from meta.errors import UserInputError
|
||||||
|
from modules.profiles.profile import UserProfile
|
||||||
from utils.lib import utc_now, error_embed
|
from utils.lib import utc_now, error_embed
|
||||||
from utils.ui import ChoicedEnum, Transformed, AButton
|
from utils.ui import ChoicedEnum, Transformed, AButton
|
||||||
|
|
||||||
@@ -126,6 +129,7 @@ class TasklistCog(LionCog):
|
|||||||
|
|
||||||
def __init__(self, bot: LionBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
self.crocbot: CrocBot = bot.crocbot
|
||||||
self.data = bot.db.load_registry(TasklistData())
|
self.data = bot.db.load_registry(TasklistData())
|
||||||
self.babel = babel
|
self.babel = babel
|
||||||
self.settings = TasklistSettings()
|
self.settings = TasklistSettings()
|
||||||
@@ -138,10 +142,84 @@ class TasklistCog(LionCog):
|
|||||||
self.bot.core.guild_config.register_model_setting(self.settings.task_reward_limit)
|
self.bot.core.guild_config.register_model_setting(self.settings.task_reward_limit)
|
||||||
self.bot.add_view(TasklistCaller(self.bot))
|
self.bot.add_view(TasklistCaller(self.bot))
|
||||||
|
|
||||||
|
self.bot.profiles.add_profile_migrator(self.migrate_profiles, name='tasklist-migrator')
|
||||||
|
|
||||||
configcog = self.bot.get_cog('ConfigCog')
|
configcog = self.bot.get_cog('ConfigCog')
|
||||||
self.crossload_group(self.configure_group, configcog.config_group)
|
self.crossload_group(self.configure_group, configcog.config_group)
|
||||||
|
|
||||||
@LionCog.listener('on_tasks_completed')
|
self._load_twitch_methods(self.crocbot)
|
||||||
|
|
||||||
|
async def cog_unload(self):
|
||||||
|
self.live_tasklists.clear()
|
||||||
|
if profiles := self.bot.get_cog('ProfileCog'):
|
||||||
|
profiles.del_profile_migrator('tasklist-migrator')
|
||||||
|
self._unload_twitch_methods(self.crocbot)
|
||||||
|
|
||||||
|
@log_wrap(action="Tasklist Profile Migration")
|
||||||
|
async def migrate_profiles(self, source_profile: UserProfile, target_profile: UserProfile):
|
||||||
|
"""
|
||||||
|
Re-assign all tasklist tasks from source profile to target profile.
|
||||||
|
TODO: Probably wants some elegant handling of the cached or running tasklists.
|
||||||
|
"""
|
||||||
|
results = ["(Tasklist)"]
|
||||||
|
sourceid = source_profile.profileid
|
||||||
|
targetid = target_profile.profileid
|
||||||
|
updated = await self.data.Task.table.update_where(userid=sourceid).set(userid=targetid)
|
||||||
|
if updated:
|
||||||
|
results.append(
|
||||||
|
f"Migrated {len(updated)} task row(s) from source profile."
|
||||||
|
)
|
||||||
|
for channel_lists in self.live_tasklists.get(sourceid, []):
|
||||||
|
for tasklist in list(channel_lists.values()):
|
||||||
|
await tasklist.close()
|
||||||
|
self.bot.dispatch('tasklist_update', profileid=targetid, summon=False)
|
||||||
|
else:
|
||||||
|
results.append(
|
||||||
|
"No tasks found in source profile, nothing to migrate!"
|
||||||
|
)
|
||||||
|
|
||||||
|
return ' '.join(results)
|
||||||
|
|
||||||
|
async def user_profile_migration(self):
|
||||||
|
"""
|
||||||
|
Manual one-shot migration method from old Discord userids to the new profileids.
|
||||||
|
"""
|
||||||
|
# First collect all the distinct userids from the tasklist
|
||||||
|
# Then create a map of userids to profileids, creating the profiles if required
|
||||||
|
# Then do updates, we can just inefficiently do updates on each distinct userid
|
||||||
|
# As long as the userids and profileids never overlap, this is fine. Fine for a one-shot
|
||||||
|
|
||||||
|
# Extract all the userids that exist in the table
|
||||||
|
rows = await self.data.Task.table.select_where().select(
|
||||||
|
userid="DISTINCT(userid)"
|
||||||
|
).with_no_adapter()
|
||||||
|
|
||||||
|
# Fetch or create discord user profiles for them
|
||||||
|
profile_map = {}
|
||||||
|
for row in rows:
|
||||||
|
userid = row['userid']
|
||||||
|
if userid > 100000:
|
||||||
|
# Assume a Discord snowflake
|
||||||
|
profile = await UserProfile.fetch_from_discordid(self.bot, userid)
|
||||||
|
|
||||||
|
if not profile:
|
||||||
|
try:
|
||||||
|
user = self.bot.get_user(userid)
|
||||||
|
if user is None:
|
||||||
|
user = await self.bot.fetch_user(userid)
|
||||||
|
except discord.HTTPException:
|
||||||
|
logger.info(f"Skipping user {userid}")
|
||||||
|
continue
|
||||||
|
profile = await UserProfile.create_from_discord(self.bot, user)
|
||||||
|
profile_map[userid] = profile
|
||||||
|
|
||||||
|
# Now iterate through
|
||||||
|
for userid, profile in profile_map.items():
|
||||||
|
logger.info(f"Migrating userid {userid} to profile {profile}")
|
||||||
|
await self.data.Task.table.update_where(userid=userid).set(userid=profile.profileid)
|
||||||
|
|
||||||
|
# Temporarily disabling integration with userid driven Economy
|
||||||
|
# @LionCog.listener('on_tasks_completed')
|
||||||
@log_wrap(action="reward tasks completed")
|
@log_wrap(action="reward tasks completed")
|
||||||
async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
|
async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
|
||||||
async with self.bot.db.connection() as conn:
|
async with self.bot.db.connection() as conn:
|
||||||
@@ -170,6 +248,9 @@ class TasklistCog(LionCog):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def is_tasklist_channel(self, channel) -> bool:
|
async def is_tasklist_channel(self, channel) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether a given Discord channel is a tasklist channel
|
||||||
|
"""
|
||||||
if not channel.guild:
|
if not channel.guild:
|
||||||
return True
|
return True
|
||||||
channels = (await self.settings.tasklist_channels.get(channel.guild.id)).value
|
channels = (await self.settings.tasklist_channels.get(channel.guild.id)).value
|
||||||
@@ -186,12 +267,16 @@ class TasklistCog(LionCog):
|
|||||||
return (channel in channels) or (channel.id in private_channels) or (channel.category in channels)
|
return (channel in channels) or (channel.id in private_channels) or (channel.category in channels)
|
||||||
|
|
||||||
async def call_tasklist(self, interaction: discord.Interaction):
|
async def call_tasklist(self, interaction: discord.Interaction):
|
||||||
|
"""
|
||||||
|
Given a Discord channel interaction, summon the interacting user's tasklist.
|
||||||
|
"""
|
||||||
await interaction.response.defer(thinking=True, ephemeral=True)
|
await interaction.response.defer(thinking=True, ephemeral=True)
|
||||||
channel = interaction.channel
|
channel = interaction.channel
|
||||||
guild = channel.guild
|
guild = channel.guild
|
||||||
userid = interaction.user.id
|
profile = await self.bot.profiles.fetch_profile_discord(interaction.user)
|
||||||
|
profileid = profile.profileid
|
||||||
|
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, userid)
|
tasklist = await Tasklist.fetch(self.bot, self.data, profileid)
|
||||||
|
|
||||||
if await self.is_tasklist_channel(channel):
|
if await self.is_tasklist_channel(channel):
|
||||||
# Check we have permissions to send a regular message here
|
# Check we have permissions to send a regular message here
|
||||||
@@ -213,7 +298,7 @@ class TasklistCog(LionCog):
|
|||||||
)
|
)
|
||||||
await interaction.edit_original_response(embed=error)
|
await interaction.edit_original_response(embed=error)
|
||||||
else:
|
else:
|
||||||
tasklistui = TasklistUI.fetch(tasklist, channel, guild, timeout=None)
|
tasklistui = TasklistUI.fetch(tasklist, channel, guild, caller=interaction.user, timeout=None)
|
||||||
await tasklistui.summon(force=True)
|
await tasklistui.summon(force=True)
|
||||||
await interaction.delete_original_response()
|
await interaction.delete_original_response()
|
||||||
else:
|
else:
|
||||||
@@ -222,14 +307,14 @@ class TasklistCog(LionCog):
|
|||||||
await tasklistui.run(interaction)
|
await tasklistui.run(interaction)
|
||||||
|
|
||||||
@LionCog.listener('on_tasklist_update')
|
@LionCog.listener('on_tasklist_update')
|
||||||
async def update_listening_tasklists(self, userid, channel=None, summon=True):
|
async def update_listening_tasklists(self, profileid, channel=None, summon=True):
|
||||||
"""
|
"""
|
||||||
Propagate a tasklist update to all persistent tasklist UIs for this user.
|
Propagate a tasklist update to all persistent tasklist UIs for this user.
|
||||||
|
|
||||||
If channel is given, also summons the UI if the channel is a tasklist channel.
|
If channel is given, also summons the UI if the channel is a tasklist channel.
|
||||||
"""
|
"""
|
||||||
# Do the given channel first, and summon if requested
|
# Do the given channel first, and summon if requested
|
||||||
if channel and (tui := TasklistUI._live_[userid].get(channel.id, None)) is not None:
|
if channel and (tui := TasklistUI._live_[profileid].get(channel.id, None)) is not None:
|
||||||
try:
|
try:
|
||||||
if summon and await self.is_tasklist_channel(channel):
|
if summon and await self.is_tasklist_channel(channel):
|
||||||
await tui.summon()
|
await tui.summon()
|
||||||
@@ -240,7 +325,7 @@ class TasklistCog(LionCog):
|
|||||||
await tui.close()
|
await tui.close()
|
||||||
|
|
||||||
# Now do the rest of the listening channels
|
# Now do the rest of the listening channels
|
||||||
listening = TasklistUI._live_[userid]
|
listening = TasklistUI._live_[profileid]
|
||||||
for cid, ui in list(listening.items()):
|
for cid, ui in list(listening.items()):
|
||||||
if channel and channel.id == cid:
|
if channel and channel.id == cid:
|
||||||
# We already did this channel
|
# We already did this channel
|
||||||
@@ -275,7 +360,7 @@ class TasklistCog(LionCog):
|
|||||||
async def tasklist_group(self, ctx: LionContext):
|
async def tasklist_group(self, ctx: LionContext):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def _task_acmpl(self, userid: int, partial: str, multi=False) -> list[appcmds.Choice]:
|
async def _task_acmpl(self, profileid: int, partial: str, multi=False) -> list[appcmds.Choice]:
|
||||||
"""
|
"""
|
||||||
Generate a list of task Choices matching a given partial string.
|
Generate a list of task Choices matching a given partial string.
|
||||||
|
|
||||||
@@ -284,7 +369,7 @@ class TasklistCog(LionCog):
|
|||||||
t = self.bot.translator.t
|
t = self.bot.translator.t
|
||||||
|
|
||||||
# Should usually be cached, so this won't trigger repetitive db access
|
# Should usually be cached, so this won't trigger repetitive db access
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, userid)
|
tasklist = await Tasklist.fetch(self.bot, self.data, profileid)
|
||||||
|
|
||||||
# Special case for an empty tasklist
|
# Special case for an empty tasklist
|
||||||
if not tasklist.tasklist:
|
if not tasklist.tasklist:
|
||||||
@@ -392,13 +477,17 @@ class TasklistCog(LionCog):
|
|||||||
"""
|
"""
|
||||||
Shared autocomplete for single task parameters.
|
Shared autocomplete for single task parameters.
|
||||||
"""
|
"""
|
||||||
return await self._task_acmpl(interaction.user.id, partial, multi=False)
|
profile = await self.bot.profiles.fetch_profile_discord(interaction.user)
|
||||||
|
profileid = profile.profileid
|
||||||
|
return await self._task_acmpl(profileid, partial, multi=False)
|
||||||
|
|
||||||
async def tasks_acmpl(self, interaction: discord.Interaction, partial: str) -> list[appcmds.Choice]:
|
async def tasks_acmpl(self, interaction: discord.Interaction, partial: str) -> list[appcmds.Choice]:
|
||||||
"""
|
"""
|
||||||
Shared autocomplete for multiple task parameters.
|
Shared autocomplete for multiple task parameters.
|
||||||
"""
|
"""
|
||||||
return await self._task_acmpl(interaction.user.id, partial, multi=True)
|
profile = await self.bot.profiles.fetch_profile_discord(interaction.user)
|
||||||
|
profileid = profile.profileid
|
||||||
|
return await self._task_acmpl(profileid, partial, multi=True)
|
||||||
|
|
||||||
@tasklist_group.command(
|
@tasklist_group.command(
|
||||||
name=_p('cmd:tasks_new', "new"),
|
name=_p('cmd:tasks_new', "new"),
|
||||||
@@ -422,7 +511,7 @@ class TasklistCog(LionCog):
|
|||||||
if not ctx.interaction:
|
if not ctx.interaction:
|
||||||
return
|
return
|
||||||
|
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
|
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
|
||||||
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
||||||
|
|
||||||
# Fetch parent task if required
|
# Fetch parent task if required
|
||||||
@@ -453,9 +542,9 @@ class TasklistCog(LionCog):
|
|||||||
)
|
)
|
||||||
await ctx.interaction.edit_original_response(
|
await ctx.interaction.edit_original_response(
|
||||||
embed=embed,
|
embed=embed,
|
||||||
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
|
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot)
|
||||||
)
|
)
|
||||||
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
|
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
|
||||||
|
|
||||||
tasklist_new_cmd.autocomplete('parent')(task_acmpl)
|
tasklist_new_cmd.autocomplete('parent')(task_acmpl)
|
||||||
|
|
||||||
@@ -523,7 +612,7 @@ class TasklistCog(LionCog):
|
|||||||
raise UserInputError(error)
|
raise UserInputError(error)
|
||||||
|
|
||||||
# Contents successfully parsed, update the tasklist.
|
# Contents successfully parsed, update the tasklist.
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
|
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
|
||||||
|
|
||||||
taskinfo = tasklist.parse_tasklist(lines)
|
taskinfo = tasklist.parse_tasklist(lines)
|
||||||
|
|
||||||
@@ -572,9 +661,9 @@ class TasklistCog(LionCog):
|
|||||||
)
|
)
|
||||||
await ctx.interaction.edit_original_response(
|
await ctx.interaction.edit_original_response(
|
||||||
embed=embed,
|
embed=embed,
|
||||||
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
|
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot)
|
||||||
)
|
)
|
||||||
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
|
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
|
||||||
|
|
||||||
@tasklist_group.command(
|
@tasklist_group.command(
|
||||||
name=_p('cmd:tasks_edit', "edit"),
|
name=_p('cmd:tasks_edit', "edit"),
|
||||||
@@ -600,7 +689,7 @@ class TasklistCog(LionCog):
|
|||||||
t = self.bot.translator.t
|
t = self.bot.translator.t
|
||||||
if not ctx.interaction:
|
if not ctx.interaction:
|
||||||
return
|
return
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
|
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
|
||||||
|
|
||||||
# Fetch task to edit
|
# Fetch task to edit
|
||||||
tid = tasklist.parse_label(taskstr) if taskstr else None
|
tid = tasklist.parse_label(taskstr) if taskstr else None
|
||||||
@@ -651,12 +740,12 @@ class TasklistCog(LionCog):
|
|||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
embed=embed,
|
embed=embed,
|
||||||
view=(
|
view=(
|
||||||
discord.utils.MISSING if ctx.channel.id in TasklistUI._live_[ctx.author.id]
|
discord.utils.MISSING if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid]
|
||||||
else TasklistCaller(self.bot)
|
else TasklistCaller(self.bot)
|
||||||
),
|
),
|
||||||
ephemeral=True
|
ephemeral=True
|
||||||
)
|
)
|
||||||
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
|
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
|
||||||
|
|
||||||
if new_content or new_parent:
|
if new_content or new_parent:
|
||||||
# Manual edit route
|
# Manual edit route
|
||||||
@@ -688,17 +777,17 @@ class TasklistCog(LionCog):
|
|||||||
async def tasklist_clear_cmd(self, ctx: LionContext):
|
async def tasklist_clear_cmd(self, ctx: LionContext):
|
||||||
t = ctx.bot.translator.t
|
t = ctx.bot.translator.t
|
||||||
|
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
|
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
|
||||||
await tasklist.update_tasklist(deleted_at=utc_now())
|
await tasklist.update_tasklist(deleted_at=utc_now())
|
||||||
await ctx.reply(
|
await ctx.reply(
|
||||||
t(_p(
|
t(_p(
|
||||||
'cmd:tasks_clear|resp:success',
|
'cmd:tasks_clear|resp:success',
|
||||||
"Your tasklist has been cleared."
|
"Your tasklist has been cleared."
|
||||||
)),
|
)),
|
||||||
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot),
|
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot),
|
||||||
ephemeral=True
|
ephemeral=True
|
||||||
)
|
)
|
||||||
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
|
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
|
||||||
|
|
||||||
@tasklist_group.command(
|
@tasklist_group.command(
|
||||||
name=_p('cmd:tasks_remove', "remove"),
|
name=_p('cmd:tasks_remove', "remove"),
|
||||||
@@ -748,7 +837,7 @@ class TasklistCog(LionCog):
|
|||||||
|
|
||||||
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
||||||
|
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
|
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
|
||||||
|
|
||||||
conditions = []
|
conditions = []
|
||||||
if taskidstr:
|
if taskidstr:
|
||||||
@@ -784,7 +873,7 @@ class TasklistCog(LionCog):
|
|||||||
elif completed is False:
|
elif completed is False:
|
||||||
conditions.append(self.data.Task.completed_at == NULL)
|
conditions.append(self.data.Task.completed_at == NULL)
|
||||||
|
|
||||||
tasks = await self.data.Task.fetch_where(*conditions, userid=ctx.author.id)
|
tasks = await self.data.Task.fetch_where(*conditions, userid=ctx.profile.profileid)
|
||||||
if not tasks:
|
if not tasks:
|
||||||
await ctx.interaction.edit_original_response(
|
await ctx.interaction.edit_original_response(
|
||||||
embed=error_embed(t(_p(
|
embed=error_embed(t(_p(
|
||||||
@@ -813,9 +902,9 @@ class TasklistCog(LionCog):
|
|||||||
)
|
)
|
||||||
await ctx.interaction.edit_original_response(
|
await ctx.interaction.edit_original_response(
|
||||||
embed=embed,
|
embed=embed,
|
||||||
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
|
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot)
|
||||||
)
|
)
|
||||||
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
|
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
|
||||||
|
|
||||||
tasklist_remove_cmd.autocomplete('taskidstr')(tasks_acmpl)
|
tasklist_remove_cmd.autocomplete('taskidstr')(tasks_acmpl)
|
||||||
|
|
||||||
@@ -844,7 +933,7 @@ class TasklistCog(LionCog):
|
|||||||
|
|
||||||
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
||||||
|
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
|
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
taskids = tasklist.parse_labels(taskidstr)
|
taskids = tasklist.parse_labels(taskidstr)
|
||||||
@@ -889,9 +978,9 @@ class TasklistCog(LionCog):
|
|||||||
)
|
)
|
||||||
await ctx.interaction.edit_original_response(
|
await ctx.interaction.edit_original_response(
|
||||||
embed=embed,
|
embed=embed,
|
||||||
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
|
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot)
|
||||||
)
|
)
|
||||||
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
|
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
|
||||||
|
|
||||||
tasklist_tick_cmd.autocomplete('taskidstr')(tasks_acmpl)
|
tasklist_tick_cmd.autocomplete('taskidstr')(tasks_acmpl)
|
||||||
|
|
||||||
@@ -920,7 +1009,7 @@ class TasklistCog(LionCog):
|
|||||||
|
|
||||||
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
||||||
|
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
|
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
taskids = tasklist.parse_labels(taskidstr)
|
taskids = tasklist.parse_labels(taskidstr)
|
||||||
@@ -962,9 +1051,9 @@ class TasklistCog(LionCog):
|
|||||||
)
|
)
|
||||||
await ctx.interaction.edit_original_response(
|
await ctx.interaction.edit_original_response(
|
||||||
embed=embed,
|
embed=embed,
|
||||||
view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
|
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot)
|
||||||
)
|
)
|
||||||
self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
|
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel)
|
||||||
|
|
||||||
tasklist_untick_cmd.autocomplete('taskidstr')(tasks_acmpl)
|
tasklist_untick_cmd.autocomplete('taskidstr')(tasks_acmpl)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from data.columns import Integer, String, Timestamp, Bool
|
|||||||
|
|
||||||
|
|
||||||
class TasklistData(Registry):
|
class TasklistData(Registry):
|
||||||
|
|
||||||
class Task(RowModel):
|
class Task(RowModel):
|
||||||
"""
|
"""
|
||||||
Row model describing a single task in a tasklist.
|
Row model describing a single task in a tasklist.
|
||||||
@@ -14,21 +15,17 @@ class TasklistData(Registry):
|
|||||||
CREATE TABLE tasklist(
|
CREATE TABLE tasklist(
|
||||||
taskid SERIAL PRIMARY KEY,
|
taskid SERIAL PRIMARY KEY,
|
||||||
userid BIGINT NOT NULL REFERENCES user_config ON DELETE CASCADE,
|
userid BIGINT NOT NULL REFERENCES user_config ON DELETE CASCADE,
|
||||||
|
profileid INTEGER NOT NULL REFERENCES user_profiles ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
parentid INTEGER REFERENCES tasklist (taskid) ON DELETE SET NULL,
|
parentid INTEGER REFERENCES tasklist (taskid) ON DELETE SET NULL,
|
||||||
content TEXT NOT NULL,
|
content TEXT NOT NULL,
|
||||||
rewarded BOOL DEFAULT FALSE,
|
rewarded BOOL DEFAULT FALSE,
|
||||||
deleted_at TIMESTAMPTZ,
|
deleted_at TIMESTAMPTZ,
|
||||||
completed_at TIMESTAMPTZ,
|
completed_at TIMESTAMPTZ,
|
||||||
created_at TIMESTAMPTZ,
|
created_at TIMESTAMPTZ,
|
||||||
|
duration INTEGER,
|
||||||
last_updated_at TIMESTAMPTZ
|
last_updated_at TIMESTAMPTZ
|
||||||
);
|
);
|
||||||
CREATE INDEX tasklist_users ON tasklist (userid);
|
CREATE INDEX tasklist_users ON tasklist (userid);
|
||||||
|
|
||||||
CREATE TABLE tasklist_channels(
|
|
||||||
guildid BIGINT NOT NULL REFERENCES guild_config (guildid) ON DELETE CASCADE,
|
|
||||||
channelid BIGINT NOT NULL
|
|
||||||
);
|
|
||||||
CREATE INDEX tasklist_channels_guilds ON tasklist_channels (guildid);
|
|
||||||
"""
|
"""
|
||||||
_tablename_ = "tasklist"
|
_tablename_ = "tasklist"
|
||||||
|
|
||||||
@@ -41,5 +38,26 @@ class TasklistData(Registry):
|
|||||||
created_at = Timestamp()
|
created_at = Timestamp()
|
||||||
deleted_at = Timestamp()
|
deleted_at = Timestamp()
|
||||||
last_updated_at = Timestamp()
|
last_updated_at = Timestamp()
|
||||||
|
duration = Integer()
|
||||||
|
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
|
||||||
|
CREATE TABLE tasklist_channels(
|
||||||
|
guildid BIGINT NOT NULL REFERENCES guild_config (guildid) ON DELETE CASCADE,
|
||||||
|
channelid BIGINT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX tasklist_channels_guilds ON tasklist_channels (guildid);
|
||||||
|
"""
|
||||||
channels = Table('tasklist_channels')
|
channels = Table('tasklist_channels')
|
||||||
|
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE current_tasks(
|
||||||
|
taskid PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
last_started_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
current_tasks = Table('current_tasks')
|
||||||
|
|||||||
23
src/modules/tasklist/migration.sql
Normal file
23
src/modules/tasklist/migration.sql
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
ALTER TABLE tasklist
|
||||||
|
DROP CONSTRAINT fk_tasklist_users;
|
||||||
|
|
||||||
|
ALTER TABLE tasklist
|
||||||
|
ADD CONSTRAINT fk_tasklist_users
|
||||||
|
FOREIGN KEY (userid)
|
||||||
|
REFERENCES user_profiles (profileid)
|
||||||
|
ON DELETE CASCADE
|
||||||
|
ON UPDATE CASCADE
|
||||||
|
NOT VALID;
|
||||||
|
ALTER TABLE tasklist
|
||||||
|
ADD COLUMN duration INTEGER;
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE tasklist_current(
|
||||||
|
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
started_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE tasklist_planner(
|
||||||
|
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
sortkey INTEGER
|
||||||
|
);
|
||||||
@@ -232,13 +232,18 @@ class TasklistUI(BasePager):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
tasklist: Tasklist,
|
tasklist: Tasklist,
|
||||||
channel: discord.abc.Messageable, guild: Optional[discord.Guild] = None, **kwargs):
|
channel: discord.abc.Messageable,
|
||||||
|
guild: Optional[discord.Guild] = None,
|
||||||
|
caller: Optional[discord.User | discord.Member] = None,
|
||||||
|
**kwargs):
|
||||||
kwargs.setdefault('timeout', 600)
|
kwargs.setdefault('timeout', 600)
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.bot = tasklist.bot
|
self.bot = tasklist.bot
|
||||||
self.tasklist = tasklist
|
self.tasklist = tasklist
|
||||||
self.labelled = tasklist.labelled
|
self.labelled = tasklist.labelled
|
||||||
|
self.caller = caller
|
||||||
|
# NOTE: This is now a profiled
|
||||||
self.userid = tasklist.userid
|
self.userid = tasklist.userid
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.guild = guild
|
self.guild = guild
|
||||||
@@ -449,9 +454,10 @@ class TasklistUI(BasePager):
|
|||||||
cascade=True,
|
cascade=True,
|
||||||
completed_at=utc_now()
|
completed_at=utc_now()
|
||||||
)
|
)
|
||||||
if self.guild:
|
# TODO: Removed economy integration
|
||||||
if (member := self.guild.get_member(self.userid)):
|
# if self.guild:
|
||||||
self.bot.dispatch('tasks_completed', member, *(t.taskid for t in to_complete))
|
# if (member := self.guild.get_member(self.userid)):
|
||||||
|
# self.bot.dispatch('tasks_completed', member, *(t.taskid for t in to_complete))
|
||||||
if to_uncomplete:
|
if to_uncomplete:
|
||||||
await self.tasklist.update_tasks(
|
await self.tasklist.update_tasks(
|
||||||
*(t.taskid for t in to_uncomplete),
|
*(t.taskid for t in to_uncomplete),
|
||||||
@@ -475,7 +481,7 @@ class TasklistUI(BasePager):
|
|||||||
if shared_root:
|
if shared_root:
|
||||||
self._subtree_root = labelled[shared_root].taskid
|
self._subtree_root = labelled[shared_root].taskid
|
||||||
|
|
||||||
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
|
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
|
||||||
|
|
||||||
async def _delete_menu(self, interaction: discord.Interaction, selected: Select, subtree: bool):
|
async def _delete_menu(self, interaction: discord.Interaction, selected: Select, subtree: bool):
|
||||||
await interaction.response.defer()
|
await interaction.response.defer()
|
||||||
@@ -486,7 +492,7 @@ class TasklistUI(BasePager):
|
|||||||
cascade=True,
|
cascade=True,
|
||||||
deleted_at=utc_now()
|
deleted_at=utc_now()
|
||||||
)
|
)
|
||||||
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
|
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
|
||||||
|
|
||||||
async def _edit_menu(self, interaction: discord.Interaction, selected: Select, subtree: bool):
|
async def _edit_menu(self, interaction: discord.Interaction, selected: Select, subtree: bool):
|
||||||
if not selected.values:
|
if not selected.values:
|
||||||
@@ -513,7 +519,7 @@ class TasklistUI(BasePager):
|
|||||||
self._last_parentid = new_parentid
|
self._last_parentid = new_parentid
|
||||||
if not subtree:
|
if not subtree:
|
||||||
self._subtree_root = new_parentid
|
self._subtree_root = new_parentid
|
||||||
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
|
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
|
||||||
|
|
||||||
await interaction.response.send_modal(editor)
|
await interaction.response.send_modal(editor)
|
||||||
|
|
||||||
@@ -606,7 +612,7 @@ class TasklistUI(BasePager):
|
|||||||
self._subtree_root = pid
|
self._subtree_root = pid
|
||||||
await interaction.response.defer()
|
await interaction.response.defer()
|
||||||
await self.tasklist.create_task(new_task, parentid=pid)
|
await self.tasklist.create_task(new_task, parentid=pid)
|
||||||
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
|
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
|
||||||
|
|
||||||
await press.response.send_modal(editor)
|
await press.response.send_modal(editor)
|
||||||
|
|
||||||
@@ -667,7 +673,7 @@ class TasklistUI(BasePager):
|
|||||||
|
|
||||||
@editor.add_callback
|
@editor.add_callback
|
||||||
async def editor_callback(interaction: discord.Interaction):
|
async def editor_callback(interaction: discord.Interaction):
|
||||||
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
|
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
|
||||||
|
|
||||||
if sum(len(line) for line in editor.lines.values()) + len(editor.lines) >= 4000:
|
if sum(len(line) for line in editor.lines.values()) + len(editor.lines) >= 4000:
|
||||||
await press.response.send_message(
|
await press.response.send_message(
|
||||||
@@ -698,7 +704,7 @@ class TasklistUI(BasePager):
|
|||||||
await self.tasklist.update_tasklist(
|
await self.tasklist.update_tasklist(
|
||||||
deleted_at=utc_now(),
|
deleted_at=utc_now(),
|
||||||
)
|
)
|
||||||
self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
|
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False)
|
||||||
|
|
||||||
async def clear_button_refresh(self):
|
async def clear_button_refresh(self):
|
||||||
self.clear_button.label = self.bot.translator.t(_p(
|
self.clear_button.label = self.bot.translator.t(_p(
|
||||||
@@ -771,11 +777,12 @@ class TasklistUI(BasePager):
|
|||||||
|
|
||||||
# ----- UI Flow -----
|
# ----- UI Flow -----
|
||||||
def access_check(self, userid):
|
def access_check(self, userid):
|
||||||
return userid == self.userid
|
return userid in (self.userid, self.caller.id if self.caller else None)
|
||||||
|
|
||||||
async def interaction_check(self, interaction: discord.Interaction):
|
async def interaction_check(self, interaction: discord.Interaction):
|
||||||
t = self.bot.translator.t
|
t = self.bot.translator.t
|
||||||
if not self.access_check(interaction.user.id):
|
interaction_profile = await self.bot.profiles.fetch_profile_discord(interaction.user)
|
||||||
|
if not self.access_check(interaction_profile.profileid):
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
description=t(_p(
|
description=t(_p(
|
||||||
'ui:tasklist|error:wrong_user',
|
'ui:tasklist|error:wrong_user',
|
||||||
@@ -812,10 +819,7 @@ class TasklistUI(BasePager):
|
|||||||
total = len(tasks)
|
total = len(tasks)
|
||||||
completed = sum(t.completed_at is not None for t in tasks)
|
completed = sum(t.completed_at is not None for t in tasks)
|
||||||
|
|
||||||
if self.guild:
|
user = self.caller
|
||||||
user = self.guild.get_member(self.userid)
|
|
||||||
else:
|
|
||||||
user = self.bot.get_user(self.userid)
|
|
||||||
user_name = user.name if user else str(self.userid)
|
user_name = user.name if user else str(self.userid)
|
||||||
user_colour = user.colour if user else discord.Color.orange()
|
user_colour = user.colour if user else discord.Color.orange()
|
||||||
|
|
||||||
|
|||||||
Submodule src/modules/voicefix updated: 5146e46515...cca1c94bd5
7
src/modules/voiceroles/__init__.py
Normal file
7
src/modules/voiceroles/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
from .cog import VoiceRoleCog
|
||||||
|
await bot.add_cog(VoiceRoleCog(bot))
|
||||||
166
src/modules/voiceroles/cog.py
Normal file
166
src/modules/voiceroles/cog.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from typing import Optional
|
||||||
|
import asyncio
|
||||||
|
from cachetools import FIFOCache
|
||||||
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.abc import GuildChannel
|
||||||
|
from discord.ext import commands as cmds
|
||||||
|
from discord import app_commands as appcmds
|
||||||
|
|
||||||
|
from meta import LionBot, LionCog, LionContext
|
||||||
|
from meta.logger import log_wrap
|
||||||
|
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
|
||||||
|
from utils.ui import Confirm
|
||||||
|
|
||||||
|
from . import logger
|
||||||
|
from .data import VoiceRoleData
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceRoleCog(LionCog):
|
||||||
|
def __init__(self, bot: LionBot):
|
||||||
|
self.bot = bot
|
||||||
|
self.data = bot.db.load_registry(VoiceRoleData())
|
||||||
|
|
||||||
|
self._event_locks: WeakValueDictionary[tuple[int, int], asyncio.Lock] = WeakValueDictionary()
|
||||||
|
|
||||||
|
async def cog_load(self):
|
||||||
|
await self.data.init()
|
||||||
|
|
||||||
|
@LionCog.listener('on_voice_state_update')
|
||||||
|
@log_wrap(action='Voice Role Update')
|
||||||
|
async def voicerole_update(self, member: discord.Member,
|
||||||
|
before: discord.VoiceState, after: discord.VoiceState):
|
||||||
|
if member.bot:
|
||||||
|
return
|
||||||
|
|
||||||
|
after_channel = after.channel
|
||||||
|
before_channel = before.channel
|
||||||
|
if after_channel == before_channel:
|
||||||
|
return
|
||||||
|
|
||||||
|
task_key = (member.guild.id, member.id)
|
||||||
|
async with self.event_lock(task_key):
|
||||||
|
# Get the roles of the channel they left to remove
|
||||||
|
# Get the roles of the channel they are joining to add
|
||||||
|
# Use a set difference to remove the roles to be added from the ones to remove
|
||||||
|
if before_channel is not None:
|
||||||
|
leaving_roles = await self.get_roles_for(before_channel.id)
|
||||||
|
else:
|
||||||
|
leaving_roles = []
|
||||||
|
|
||||||
|
if after_channel is not None:
|
||||||
|
gaining_roles = await self.get_roles_for(after_channel.id)
|
||||||
|
else:
|
||||||
|
gaining_roles = []
|
||||||
|
|
||||||
|
to_remove = []
|
||||||
|
for role in leaving_roles:
|
||||||
|
if role in member.roles and role not in gaining_roles and role.is_assignable():
|
||||||
|
to_remove.append(role)
|
||||||
|
|
||||||
|
to_add = []
|
||||||
|
for role in gaining_roles:
|
||||||
|
if role not in member.roles and role.is_assignable():
|
||||||
|
to_add.append(role)
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
await member.remove_roles(*to_remove, reason="Removing voice channel associated roles.")
|
||||||
|
if to_add:
|
||||||
|
await member.add_roles(*to_add, reason="Adding voice channel associated roles.")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Voice roles removed {len(to_remove)} roles "
|
||||||
|
f"and added {len(to_add)} roles to <uid: {member.id}>"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_roles_for(self, channelid: int) -> list[discord.Role]:
|
||||||
|
"""
|
||||||
|
Get the voice roles associated to the given channel, as a list.
|
||||||
|
|
||||||
|
Returns an empty list if there are no associated voice roles.
|
||||||
|
"""
|
||||||
|
rows = await self.data.VoiceRole.fetch_where(channelid=channelid)
|
||||||
|
channel = self.bot.get_channel(channelid)
|
||||||
|
if not channel:
|
||||||
|
raise ValueError("Provided voice role target channel is not in cache.")
|
||||||
|
|
||||||
|
target_roles = []
|
||||||
|
for row in rows:
|
||||||
|
role = channel.guild.get_role(row.roleid)
|
||||||
|
if role is not None:
|
||||||
|
target_roles.append(role)
|
||||||
|
|
||||||
|
return target_roles
|
||||||
|
|
||||||
|
def event_lock(self, key) -> asyncio.Lock:
|
||||||
|
"""
|
||||||
|
Get an asyncio.Lock for the given key.
|
||||||
|
|
||||||
|
Guarantees sequential event handling.
|
||||||
|
"""
|
||||||
|
lock = self._event_locks.get(key, None)
|
||||||
|
if lock is None:
|
||||||
|
lock = self._event_locks[key] = asyncio.Lock()
|
||||||
|
logger.debug(f"Getting video event lock {key} (locked: {lock.locked()})")
|
||||||
|
return lock
|
||||||
|
|
||||||
|
|
||||||
|
# -------- Commands --------
|
||||||
|
@cmds.hybrid_group(
|
||||||
|
name='voiceroles',
|
||||||
|
description="Base command group for voice channel -> role associationes."
|
||||||
|
)
|
||||||
|
@appcmds.default_permissions(manage_channels=True)
|
||||||
|
async def voicerole_group(self, ctx: LionContext):
|
||||||
|
...
|
||||||
|
|
||||||
|
@voicerole_group.command(
|
||||||
|
name="link",
|
||||||
|
description="Link a given voice channel with a given role."
|
||||||
|
)
|
||||||
|
@appcmds.describe(
|
||||||
|
channel="The voice channel to link.",
|
||||||
|
role="The associated role to give to members joining the voice channel."
|
||||||
|
)
|
||||||
|
async def voicerole_link(self, ctx: LionContext,
|
||||||
|
channel: discord.VoiceChannel,
|
||||||
|
role: discord.Role):
|
||||||
|
if not ctx.interaction:
|
||||||
|
return
|
||||||
|
if not channel.permissions_for(ctx.author).manage_channels:
|
||||||
|
await ctx.error_reply(f"You don't have the manage channels permission in {channel.mention}")
|
||||||
|
return
|
||||||
|
if not ctx.author.guild_permissions.manage_roles or not (role < ctx.author.top_role):
|
||||||
|
await ctx.error_reply(f"You don't have the permission to manage this role!")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.data.VoiceRole.table.insert(channelid=channel.id, roleid=role.id)
|
||||||
|
await ctx.reply("Voice role associated!")
|
||||||
|
|
||||||
|
@voicerole_group.command(
|
||||||
|
name="unlink",
|
||||||
|
description="Unlink a given voice channel from a given role."
|
||||||
|
)
|
||||||
|
@appcmds.describe(
|
||||||
|
channel="The voice channel to unlink.",
|
||||||
|
role="The role to remove from this voice channel."
|
||||||
|
)
|
||||||
|
async def voicerole_unlink(self, ctx: LionContext,
|
||||||
|
channel: discord.VoiceChannel,
|
||||||
|
role: discord.Role):
|
||||||
|
if not ctx.interaction:
|
||||||
|
return
|
||||||
|
if not channel.permissions_for(ctx.author).manage_channels:
|
||||||
|
await ctx.error_reply(f"You don't have the manage channels permission in {channel.mention}")
|
||||||
|
return
|
||||||
|
if not ctx.author.guild_permissions.manage_roles or not (role < ctx.author.top_role):
|
||||||
|
await ctx.error_reply(f"You don't have the permission to manage this role!")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.data.VoiceRole.table.delete_where(channelid=channel.id, roleid=role.id)
|
||||||
|
await ctx.reply("Voice role disassociated!")
|
||||||
|
|
||||||
|
# TODO: Display and visual editing of roles.
|
||||||
|
|
||||||
27
src/modules/voiceroles/data.py
Normal file
27
src/modules/voiceroles/data.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from data import Registry, RowModel
|
||||||
|
from data.columns import Integer, Timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceRoleData(Registry):
|
||||||
|
class VoiceRole(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE voice_roles(
|
||||||
|
voice_role_id SERIAL PRIMARY KEY,
|
||||||
|
channelid BIGINT NOT NULL,
|
||||||
|
roleid BIGINT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX voice_role_channels on voice_roles (channelid);
|
||||||
|
"""
|
||||||
|
# TODO: Worth associating a guildid to this as well? Denormalises though
|
||||||
|
# Makes more theoretical sense to associated configurable channels to the guilds in a join table.
|
||||||
|
_tablename_ = 'voice_roles'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
voice_role_id = Integer(primary=True)
|
||||||
|
channelid = Integer()
|
||||||
|
roleid = Integer()
|
||||||
|
|
||||||
|
created_at = Timestamp()
|
||||||
@@ -5,7 +5,7 @@ import datetime as dt
|
|||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands as cmds
|
from discord.ext import commands as cmds
|
||||||
from discord import app_commands as appcmds
|
from discord import AllowedMentions, app_commands as appcmds
|
||||||
|
|
||||||
from data import Condition
|
from data import Condition
|
||||||
from meta import LionBot, LionCog, LionContext
|
from meta import LionBot, LionCog, LionContext
|
||||||
@@ -654,7 +654,7 @@ class VoiceTrackerCog(LionCog):
|
|||||||
|
|
||||||
# ----- Commands -----
|
# ----- Commands -----
|
||||||
@cmds.hybrid_command(
|
@cmds.hybrid_command(
|
||||||
name=_p('cmd:now', "now"),
|
name="tag",
|
||||||
description=_p(
|
description=_p(
|
||||||
'cmd:now|desc',
|
'cmd:now|desc',
|
||||||
"Describe what you are working on, or see what your friends are working on!"
|
"Describe what you are working on, or see what your friends are working on!"
|
||||||
@@ -668,7 +668,7 @@ class VoiceTrackerCog(LionCog):
|
|||||||
@appcmds.describe(
|
@appcmds.describe(
|
||||||
tag=_p(
|
tag=_p(
|
||||||
'cmd:now|param:tag|desc',
|
'cmd:now|param:tag|desc',
|
||||||
"Describe what you are working on in 10 characters or less!"
|
"Describe what you are working!"
|
||||||
),
|
),
|
||||||
user=_p(
|
user=_p(
|
||||||
'cmd:now|param:user|desc',
|
'cmd:now|param:user|desc',
|
||||||
@@ -681,17 +681,15 @@ class VoiceTrackerCog(LionCog):
|
|||||||
)
|
)
|
||||||
@appcmds.guild_only
|
@appcmds.guild_only
|
||||||
async def now_cmd(self, ctx: LionContext,
|
async def now_cmd(self, ctx: LionContext,
|
||||||
tag: Optional[appcmds.Range[str, 0, 10]] = None,
|
tag: Optional[str] = None,
|
||||||
|
*,
|
||||||
user: Optional[discord.Member] = None,
|
user: Optional[discord.Member] = None,
|
||||||
clear: Optional[bool] = None
|
clear: Optional[bool] = None
|
||||||
):
|
):
|
||||||
if not ctx.guild:
|
if not ctx.guild:
|
||||||
return
|
return
|
||||||
if not ctx.interaction:
|
|
||||||
return
|
|
||||||
t = self.bot.translator.t
|
t = self.bot.translator.t
|
||||||
|
|
||||||
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
|
||||||
is_moderator = await moderator_ctxward(ctx)
|
is_moderator = await moderator_ctxward(ctx)
|
||||||
target = user if user is not None else ctx.author
|
target = user if user is not None else ctx.author
|
||||||
session = self.get_session(ctx.guild.id, target.id, create=False)
|
session = self.get_session(ctx.guild.id, target.id, create=False)
|
||||||
@@ -715,7 +713,7 @@ class VoiceTrackerCog(LionCog):
|
|||||||
"{mention} has no running session!"
|
"{mention} has no running session!"
|
||||||
)).format(mention=target.mention)
|
)).format(mention=target.mention)
|
||||||
)
|
)
|
||||||
await ctx.interaction.edit_original_response(embed=error)
|
await ctx.reply(embed=error)
|
||||||
return
|
return
|
||||||
|
|
||||||
if clear:
|
if clear:
|
||||||
@@ -723,87 +721,27 @@ class VoiceTrackerCog(LionCog):
|
|||||||
if target == ctx.author:
|
if target == ctx.author:
|
||||||
# Clear the author's tag
|
# Clear the author's tag
|
||||||
await session.set_tag(None)
|
await session.set_tag(None)
|
||||||
ack = discord.Embed(
|
ack = "Cleared your current task!"
|
||||||
colour=discord.Colour.brand_green(),
|
|
||||||
title=t(_p(
|
|
||||||
'cmd:now|target:self|mode:clear|success|title',
|
|
||||||
"Session Tag Cleared"
|
|
||||||
)),
|
|
||||||
description=t(_p(
|
|
||||||
'cmd:now|target:self|mode:clear|success|desc',
|
|
||||||
"Successfully unset your session tag."
|
|
||||||
))
|
|
||||||
)
|
|
||||||
elif not is_moderator:
|
elif not is_moderator:
|
||||||
# Trying to clear someone else's tag without being a moderator
|
# Trying to clear someone else's tag without being a moderator
|
||||||
ack = discord.Embed(
|
ack = "You need to be a moderator to set or clear someone else's task!"
|
||||||
colour=discord.Colour.brand_red(),
|
|
||||||
title=t(_p(
|
|
||||||
'cmd:now|target:other|mode:clear|error:perms|title',
|
|
||||||
"You can't do that!"
|
|
||||||
)),
|
|
||||||
description=t(_p(
|
|
||||||
'cmd:now|target:other|mode:clear|error:perms|desc',
|
|
||||||
"You need to be a moderator to set or clear someone else's session tag."
|
|
||||||
))
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Clearing someone else's tag as a moderator
|
# Clearing someone else's tag as a moderator
|
||||||
await session.set_tag(None)
|
await session.set_tag(None)
|
||||||
ack = discord.Embed(
|
ack = f"Cleared {target}'s current task!"
|
||||||
colour=discord.Colour.brand_green(),
|
|
||||||
title=t(_p(
|
|
||||||
'cmd:now|target:other|mode:clear|success|title',
|
|
||||||
"Session Tag Cleared!"
|
|
||||||
)),
|
|
||||||
description=t(_p(
|
|
||||||
'cmd:now|target:other|mode:clear|success|desc',
|
|
||||||
"Cleared {target}'s session tag."
|
|
||||||
)).format(target=target.mention)
|
|
||||||
)
|
|
||||||
elif tag:
|
elif tag:
|
||||||
# Tag setting mode
|
# Tag setting mode
|
||||||
if target == ctx.author:
|
if target == ctx.author:
|
||||||
# Set the author's tag
|
# Set the author's tag
|
||||||
await session.set_tag(tag)
|
await session.set_tag(tag)
|
||||||
ack = discord.Embed(
|
ack = f"Set your current task to `{tag}`, good luck! <:goodluck:1266447460146876497>"
|
||||||
colour=discord.Colour.brand_green(),
|
|
||||||
title=t(_p(
|
|
||||||
'cmd:now|target:self|mode:set|success|title',
|
|
||||||
"Session Tag Set!"
|
|
||||||
)),
|
|
||||||
description=t(_p(
|
|
||||||
'cmd:now|target:self|mode:set|success|desc',
|
|
||||||
"You are now working on `{new_tag}`. Good luck!"
|
|
||||||
)).format(new_tag=tag)
|
|
||||||
)
|
|
||||||
elif not is_moderator:
|
elif not is_moderator:
|
||||||
# Trying the set someone else's tag without being a moderator
|
# Trying the set someone else's tag without being a moderator
|
||||||
ack = discord.Embed(
|
ack = "You need to be a moderator to set or clear someone else's task!"
|
||||||
colour=discord.Colour.brand_red(),
|
|
||||||
title=t(_p(
|
|
||||||
'cmd:now|target:other|mode:set|error:perms|title',
|
|
||||||
"You can't do that!"
|
|
||||||
)),
|
|
||||||
description=t(_p(
|
|
||||||
'cmd:now|target:other|mode:set|error:perms|desc',
|
|
||||||
"You need to be a moderator to set or clear someone else's session tag!"
|
|
||||||
))
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Setting someone else's tag as a moderator
|
# Setting someone else's tag as a moderator
|
||||||
await session.set_tag(tag)
|
await session.set_tag(tag)
|
||||||
ack = discord.Embed(
|
ack = f"Set {target}'s current task to `{tag}`"
|
||||||
colour=discord.Colour.brand_green(),
|
|
||||||
title=t(_p(
|
|
||||||
'cmd:now|target:other|mode:set|success|title',
|
|
||||||
"Session Tag Set!"
|
|
||||||
)),
|
|
||||||
description=t(_p(
|
|
||||||
'cmd:now|target:other|mode:set|success|desc',
|
|
||||||
"Set {target}'s session tag to `{new_tag}`."
|
|
||||||
)).format(target=target.mention, new_tag=tag)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Display tag and voice time
|
# Display tag and voice time
|
||||||
if target == ctx.author:
|
if target == ctx.author:
|
||||||
@@ -815,14 +753,14 @@ class VoiceTrackerCog(LionCog):
|
|||||||
else:
|
else:
|
||||||
desc = t(_p(
|
desc = t(_p(
|
||||||
'cmd:now|target:self|mode:show_without_tag|desc',
|
'cmd:now|target:self|mode:show_without_tag|desc',
|
||||||
"You have been working in {channel} since {time}!\n\n"
|
"You have been working in {channel} since {time}! "
|
||||||
"Use `/now <tag>` to set what you are working on."
|
"Use `/now <tag>` to set what you are working on."
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
if session.tag:
|
if session.tag:
|
||||||
desc = t(_p(
|
desc = t(_p(
|
||||||
'cmd:now|target:other|mode:show_with_tag|desc',
|
'cmd:now|target:other|mode:show_with_tag|desc',
|
||||||
"{target} is current working in {channel}!\n"
|
"{target} is current working in {channel}! "
|
||||||
"They have been working on **{tag}** since {time}."
|
"They have been working on **{tag}** since {time}."
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
@@ -830,18 +768,13 @@ class VoiceTrackerCog(LionCog):
|
|||||||
'cmd:now|target:other|mode:show_without_tag|desc',
|
'cmd:now|target:other|mode:show_without_tag|desc',
|
||||||
"{target} has been working in {channel} since {time}!"
|
"{target} has been working in {channel} since {time}!"
|
||||||
))
|
))
|
||||||
desc = desc.format(
|
ack = desc.format(
|
||||||
tag=session.tag,
|
tag=session.tag,
|
||||||
channel=f"<#{session.state.channelid}>",
|
channel=f"<#{session.state.channelid}>",
|
||||||
time=discord.utils.format_dt(session.start_time, 't'),
|
time=discord.utils.format_dt(session.start_time, 'R'),
|
||||||
target=target.mention,
|
target=target.mention,
|
||||||
)
|
)
|
||||||
ack = discord.Embed(
|
await ctx.reply(ack, allowed_mentions=AllowedMentions.none())
|
||||||
colour=discord.Colour.orange(),
|
|
||||||
description=desc,
|
|
||||||
timestamp=utc_now()
|
|
||||||
)
|
|
||||||
await ctx.interaction.edit_original_response(embed=ack)
|
|
||||||
|
|
||||||
# ----- Configuration Commands -----
|
# ----- Configuration Commands -----
|
||||||
@LionCog.placeholder_group
|
@LionCog.placeholder_group
|
||||||
|
|||||||
9
src/twitch/__init__.py
Normal file
9
src/twitch/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from .cog import TwitchAuthCog
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
await bot.add_cog(TwitchAuthCog(bot))
|
||||||
|
|
||||||
50
src/twitch/authclient.py
Normal file
50
src/twitch/authclient.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""
|
||||||
|
Testing client for the twitch AuthServer.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from twitchAPI.twitch import Twitch
|
||||||
|
from twitchAPI.oauth import UserAuthenticator
|
||||||
|
from twitchAPI.type import AuthScope
|
||||||
|
|
||||||
|
from meta.config import conf
|
||||||
|
|
||||||
|
|
||||||
|
URI = "http://localhost:3000/twiauth/confirm"
|
||||||
|
TARGET_SCOPE = [AuthScope.CHAT_EDIT, AuthScope.CHAT_READ]
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Load in client id and secret
|
||||||
|
twitch = await Twitch(conf.twitch['app_id'], conf.twitch['app_secret'])
|
||||||
|
auth = UserAuthenticator(twitch, TARGET_SCOPE, url=URI)
|
||||||
|
url = auth.return_auth_url()
|
||||||
|
|
||||||
|
# Post url to user
|
||||||
|
print(url)
|
||||||
|
|
||||||
|
# Send listen request to server
|
||||||
|
# Wait for listen request
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.ws_connect('http://localhost:3000/twiauth/listen') as ws:
|
||||||
|
await ws.send_json({'state': auth.state})
|
||||||
|
result = await ws.receive_json()
|
||||||
|
|
||||||
|
# Hopefully get back code, print the response
|
||||||
|
print(f"Recieved: {result}")
|
||||||
|
|
||||||
|
# Authorise with code and client details
|
||||||
|
tokens = await auth.authenticate(user_token=result['code'])
|
||||||
|
if tokens:
|
||||||
|
token, refresh = tokens
|
||||||
|
await twitch.set_user_authentication(token, TARGET_SCOPE, refresh)
|
||||||
|
print(f"Authorised!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(main())
|
||||||
86
src/twitch/authserver.py
Normal file
86
src/twitch/authserver.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
import asyncio
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
reqid: ContextVar[str] = ContextVar('reqid', default='ROOT')
|
||||||
|
|
||||||
|
|
||||||
|
class AuthServer:
|
||||||
|
def __init__(self):
|
||||||
|
self.listeners = {}
|
||||||
|
|
||||||
|
async def handle_twitch_callback(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
args = request.query
|
||||||
|
if 'state' not in args:
|
||||||
|
raise web.HTTPBadRequest(text="No state provided.")
|
||||||
|
if args['state'] not in self.listeners:
|
||||||
|
raise web.HTTPBadRequest(text="Invalid state.")
|
||||||
|
self.listeners[args['state']].set_result(dict(args))
|
||||||
|
return web.Response(text="Authorisation complete! You may now close this page and return to the application.")
|
||||||
|
|
||||||
|
async def handle_listen_request(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
_reqid = str(uuid.uuid1())
|
||||||
|
reqid.set(_reqid)
|
||||||
|
|
||||||
|
logger.debug(f"[reqid: {_reqid}] Received websocket listen connection: {request!r}")
|
||||||
|
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
# Get the listen request data
|
||||||
|
try:
|
||||||
|
listen_req = await ws.receive_json(timeout=60)
|
||||||
|
logger.info(f"[reqid: {_reqid}] Received websocket listen request: {request}")
|
||||||
|
if 'state' not in listen_req:
|
||||||
|
logger.error(f"[reqid: {_reqid}] Websocket listen request is missing state, cancelling.")
|
||||||
|
raise web.HTTPBadRequest(text="Listen request must include state string.")
|
||||||
|
elif listen_req['state'] in self.listeners:
|
||||||
|
logger.error(f"[reqid: {_reqid}] Websocket listen request with duplicate state, cancelling.")
|
||||||
|
raise web.HTTPBadRequest(text="Invalid state string.")
|
||||||
|
except ValueError:
|
||||||
|
logger.exception(f"[reqid: {_reqid}] Listen request could not be parsed to JSON.")
|
||||||
|
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
|
||||||
|
except TypeError:
|
||||||
|
logger.exception(f"[reqid: {_reqid}] Listen request was binary not JSON.")
|
||||||
|
raise web.HTTPBadRequest(text="Request must be a JSON formatted string.")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.info(f"[reqid: {_reqid}] Timed out waiting for listen request data.")
|
||||||
|
raise web.HTTPRequestTimeout(text="Request must be a JSON formatted string.")
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f"[reqid: {_reqid}] Unknown exception.")
|
||||||
|
raise web.HTTPInternalServerError()
|
||||||
|
|
||||||
|
try:
|
||||||
|
fut = self.listeners[listen_req['state']] = asyncio.Future()
|
||||||
|
result = await asyncio.wait_for(fut, timeout=120)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.info(f"[reqid: {_reqid}] Timed out waiting for auth callback from Twitch, closing.")
|
||||||
|
raise web.HTTPGatewayTimeout(text="Did not receive an authorisation code from Twitch in time.")
|
||||||
|
finally:
|
||||||
|
self.listeners.pop(listen_req['state'], None)
|
||||||
|
|
||||||
|
logger.debug(f"[reqid: {_reqid}] Responding with auth result {result}.")
|
||||||
|
await ws.send_json(result)
|
||||||
|
await ws.close()
|
||||||
|
logger.debug(f"[reqid: {_reqid}] Request completed handling.")
|
||||||
|
|
||||||
|
return ws
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
app = web.Application()
|
||||||
|
server = AuthServer()
|
||||||
|
app.router.add_get("/twiauth/confirm", server.handle_twitch_callback)
|
||||||
|
app.router.add_get("/twiauth/listen", server.handle_listen_request)
|
||||||
|
|
||||||
|
logger.info("App setup and configured. Starting now.")
|
||||||
|
web.run_app(app, port=int(argv[1]) if len(argv) > 1 else 8080)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import sys
|
||||||
|
main(sys.argv)
|
||||||
84
src/twitch/cog.py
Normal file
84
src/twitch/cog.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import asyncio
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands as cmds
|
||||||
|
|
||||||
|
from twitchAPI.oauth import UserAuthenticator
|
||||||
|
from twitchAPI.twitch import AuthType, Twitch
|
||||||
|
from twitchAPI.type import AuthScope
|
||||||
|
import twitchio
|
||||||
|
from twitchio.ext import commands
|
||||||
|
|
||||||
|
|
||||||
|
from data.queries import ORDER
|
||||||
|
from meta import LionCog, LionBot, CrocBot
|
||||||
|
from meta.LionContext import LionContext
|
||||||
|
from twitch.userflow import UserAuthFlow
|
||||||
|
from utils.lib import utc_now
|
||||||
|
from . import logger
|
||||||
|
from .data import TwitchAuthData
|
||||||
|
|
||||||
|
|
||||||
|
class TwitchAuthCog(LionCog):
|
||||||
|
DEFAULT_SCOPES = []
|
||||||
|
|
||||||
|
def __init__(self, bot: LionBot):
|
||||||
|
self.bot = bot
|
||||||
|
self.data = bot.db.load_registry(TwitchAuthData())
|
||||||
|
|
||||||
|
async def cog_load(self):
|
||||||
|
await self.data.init()
|
||||||
|
|
||||||
|
# ----- Auth API -----
|
||||||
|
|
||||||
|
async def fetch_client_for(self, userid: int):
|
||||||
|
...
|
||||||
|
|
||||||
|
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
|
||||||
|
"""
|
||||||
|
Checks whether the given userid is authorised.
|
||||||
|
If 'scopes' is given, will also check the user has all of the given scopes.
|
||||||
|
"""
|
||||||
|
authrow = await self.data.UserAuthRow.fetch(userid)
|
||||||
|
if authrow:
|
||||||
|
if scopes:
|
||||||
|
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||||
|
has_auth = set(map(str, scopes)).issubset(has_scopes)
|
||||||
|
else:
|
||||||
|
has_auth = True
|
||||||
|
else:
|
||||||
|
has_auth = False
|
||||||
|
return has_auth
|
||||||
|
|
||||||
|
async def start_auth_for(self, userid: str, scopes: list[AuthScope] = []):
|
||||||
|
"""
|
||||||
|
Start the user authentication flow for the given userid.
|
||||||
|
Will request the given scopes along with the default ones and any existing scopes.
|
||||||
|
"""
|
||||||
|
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||||
|
existing = map(AuthScope, existing_strs)
|
||||||
|
to_request = set(existing).union(scopes)
|
||||||
|
return await self.start_auth(to_request)
|
||||||
|
|
||||||
|
async def start_auth(self, scopes = []):
|
||||||
|
# TODO: Work out a way to just clone the current twitch object
|
||||||
|
# Or can we otherwise build UserAuthenticator without app auth?
|
||||||
|
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
|
||||||
|
auth = UserAuthenticator(twitch, scopes, url=self.bot.config.twitchauth['callback_uri'])
|
||||||
|
flow = UserAuthFlow(self.data, auth, self.bot.config.twitchauth['ws_url'])
|
||||||
|
await flow.setup()
|
||||||
|
|
||||||
|
return flow
|
||||||
|
|
||||||
|
# ----- Commands -----
|
||||||
|
@cmds.hybrid_command(name='auth')
|
||||||
|
async def cmd_auth(self, ctx: LionContext):
|
||||||
|
if ctx.interaction:
|
||||||
|
await ctx.interaction.response.defer(ephemeral=True)
|
||||||
|
flow = await self.start_auth()
|
||||||
|
await ctx.reply(flow.auth.return_auth_url())
|
||||||
|
await flow.run()
|
||||||
|
await ctx.reply("Authentication Complete!")
|
||||||
79
src/twitch/data.py
Normal file
79
src/twitch/data.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import datetime as dt
|
||||||
|
|
||||||
|
from data import Registry, RowModel, Table
|
||||||
|
from data.columns import Integer, String, Timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class TwitchAuthData(Registry):
|
||||||
|
class UserAuthRow(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE twitch_user_auth(
|
||||||
|
userid TEXT PRIMARY KEY,
|
||||||
|
access_token TEXT NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
refresh_token TEXT NOT NULL,
|
||||||
|
obtained_at TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'twitch_user_auth'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
userid = Integer(primary=True)
|
||||||
|
access_token = String()
|
||||||
|
refresh_token = String()
|
||||||
|
expires_at = Timestamp()
|
||||||
|
obtained_at = Timestamp()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def update_user_auth(
|
||||||
|
cls, userid: str, token: str, refresh: str,
|
||||||
|
expires_at: dt.datetime, obtained_at: dt.datetime,
|
||||||
|
scopes: list[str]
|
||||||
|
):
|
||||||
|
if cls._connector is None:
|
||||||
|
raise ValueError("Attempting to use uninitialised Registry.")
|
||||||
|
async with cls._connector.connection() as conn:
|
||||||
|
cls._connector.conn = conn
|
||||||
|
async with conn.transaction():
|
||||||
|
# Clear row for this userid
|
||||||
|
await cls.table.delete_where(userid=userid)
|
||||||
|
|
||||||
|
# Insert new user row
|
||||||
|
row = await cls.create(
|
||||||
|
userid=userid,
|
||||||
|
access_token=token,
|
||||||
|
refresh_token=refresh,
|
||||||
|
expires_at=expires_at,
|
||||||
|
obtained_at=obtained_at
|
||||||
|
)
|
||||||
|
# Insert new scope rows
|
||||||
|
if scopes:
|
||||||
|
await TwitchAuthData.user_scopes.insert_many(
|
||||||
|
('userid', 'scope'),
|
||||||
|
*((userid, scope) for scope in scopes)
|
||||||
|
)
|
||||||
|
return row
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_scopes_for(cls, userid: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Get a list of scopes stored for the given user.
|
||||||
|
Will return an empty list if the user is not authenticated.
|
||||||
|
"""
|
||||||
|
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
|
||||||
|
|
||||||
|
return [row.scope for row in rows] if rows else []
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE twitch_user_scopes(
|
||||||
|
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
scope TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||||
|
"""
|
||||||
|
user_scopes = Table('twitch_user_scopes')
|
||||||
0
src/twitch/lib.py
Normal file
0
src/twitch/lib.py
Normal file
88
src/twitch/userflow.py
Normal file
88
src/twitch/userflow.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import datetime as dt
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from twitchAPI.twitch import Twitch
|
||||||
|
from twitchAPI.oauth import UserAuthenticator, validate_token
|
||||||
|
from twitchAPI.type import AuthType
|
||||||
|
from twitchio.client import asyncio
|
||||||
|
|
||||||
|
from meta.errors import SafeCancellation
|
||||||
|
from utils.lib import utc_now
|
||||||
|
from .data import TwitchAuthData
|
||||||
|
from . import logger
|
||||||
|
|
||||||
|
class UserAuthFlow:
|
||||||
|
auth: UserAuthenticator
|
||||||
|
data: TwitchAuthData
|
||||||
|
auth_ws: str
|
||||||
|
|
||||||
|
def __init__(self, data, auth, auth_ws):
|
||||||
|
self.auth = auth
|
||||||
|
self.data = data
|
||||||
|
self.auth_ws = auth_ws
|
||||||
|
|
||||||
|
self._setup_done = asyncio.Event()
|
||||||
|
self._comm_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
"""
|
||||||
|
Establishes websocket connection to the AuthServer,
|
||||||
|
and requests listening for the given state.
|
||||||
|
Propagates any exceptions that occur during connection setup.
|
||||||
|
"""
|
||||||
|
if self._setup_done.is_set():
|
||||||
|
raise ValueError("UserAuthFlow is already set up.")
|
||||||
|
self._comm_task = asyncio.create_task(self._communicate(), name='UserAuthFlow-communicate')
|
||||||
|
await self._setup_done.wait()
|
||||||
|
if self._comm_task.done() and (exc := self._comm_task.exception()):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
async def _communicate(self):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.ws_connect(self.auth_ws) as ws:
|
||||||
|
await ws.send_json({'state': self.auth.state})
|
||||||
|
self._setup_done.set()
|
||||||
|
return await ws.receive_json()
|
||||||
|
|
||||||
|
async def run(self) -> TwitchAuthData.UserAuthRow:
|
||||||
|
if not self._setup_done.is_set():
|
||||||
|
raise ValueError("Cannot run UserAuthFlow before setup.")
|
||||||
|
if self._comm_task is None:
|
||||||
|
raise ValueError("UserAuthFlow running with no comm task! This should be impossible.")
|
||||||
|
|
||||||
|
result = await self._comm_task
|
||||||
|
if result.get('error', None):
|
||||||
|
# TODO Custom auth errors
|
||||||
|
# This is only documented to occur when the user denies the auth
|
||||||
|
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
|
||||||
|
|
||||||
|
if result.get('state', None) != self.auth.state:
|
||||||
|
# This should never happen unless the authserver has its wires crossed somehow,
|
||||||
|
# or the connection has been tampered with.
|
||||||
|
# TODO: Consider terminating for safety in this case? Or at least refusing more auth requests.
|
||||||
|
logger.critical(
|
||||||
|
f"Received {result} while waiting for state {self.auth.state!r}. SOMETHING IS WRONG."
|
||||||
|
)
|
||||||
|
raise SafeCancellation(
|
||||||
|
"Could not complete authentication! Invalid server response."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now assume result has a valid code
|
||||||
|
# Exchange code for an auth token and a refresh token
|
||||||
|
# Ignore type here, authenticate returns None if a callback function has been given.
|
||||||
|
token, refresh = await self.auth.authenticate(user_token=result['code']) # type: ignore
|
||||||
|
|
||||||
|
# Fetch the associated userid and basic info
|
||||||
|
v_result = await validate_token(token)
|
||||||
|
userid = v_result['user_id']
|
||||||
|
expiry = utc_now() + dt.timedelta(seconds=v_result['expires_in'])
|
||||||
|
|
||||||
|
# Save auth data
|
||||||
|
return await self.data.UserAuthRow.update_user_auth(
|
||||||
|
userid=userid, token=token, refresh=refresh,
|
||||||
|
expires_at=expiry, obtained_at=utc_now(),
|
||||||
|
scopes=[scope.value for scope in self.auth.scopes]
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user