Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 096831ff38 | |||
| 6f6f1c8711 | |||
| f7156a42ce | |||
| 0eaa73b32a | |||
| d4fa04e696 | |||
| 6988daccf2 | |||
| 75565d2b3c | |||
| e4f8a043a8 | |||
| 7cbb6adcb8 | |||
| 010d52e72e | |||
| c5e9cb1488 | |||
| d1114f1a06 | |||
| 2d87783c3e | |||
| d1c5c4a0af | |||
| e5c788dfae | |||
| 2b650c220b | |||
| ed493c3988 | |||
| f45813195d | |||
| 3450f4a4b2 | |||
| d4870740a2 | |||
| 8991b1a641 | |||
| 79645177bd | |||
| 9b3b7265d3 | |||
| 3c0d527501 | |||
| 997804c6bf | |||
| 2cdd084bbe | |||
| 72d52b6014 | |||
| 92fee23afa | |||
| 83a63e8a6e | |||
| 63152f3475 | |||
| 81e25e7efc | |||
| ce07f7ae73 | |||
| d158aed257 | |||
| 47a52d9600 | |||
| fc459ac0dd | |||
| 45b57b4eca | |||
| 22b99717db | |||
| 2810365588 | |||
| e9946a9814 | |||
| 8f6fdf3381 | |||
| 9d0d19d046 | |||
| a7eb8d0f09 | |||
| 9c738ecb91 | |||
| 9c9107bf9d | |||
| f2c449d2e0 | |||
| 53366c0333 | |||
| 66f7680482 | |||
| 37f25f10ef | |||
| 87488eaf99 | |||
| bc073363b9 |
6
.gitmodules
vendored
6
.gitmodules
vendored
@@ -1,12 +1,12 @@
|
||||
[submodule "bot/gui"]
|
||||
path = src/gui
|
||||
url = https://github.com/StudyLions/StudyLion-Plugin-GUI.git
|
||||
url = git@github.com:Intery/CafeHelper-GUI.git
|
||||
[submodule "skins"]
|
||||
path = skins
|
||||
url = https://github.com/Intery/pillow-skins.git
|
||||
url = git@github.com:Intery/CafeHelper-Skins.git
|
||||
[submodule "src/modules/voicefix"]
|
||||
path = src/modules/voicefix
|
||||
url = https://github.com/Intery/StudyLion-voicefix.git
|
||||
url = git@github.com:Intery/StudyLion-voicefix.git
|
||||
[submodule "src/modules/streamalerts"]
|
||||
path = src/modules/streamalerts
|
||||
url = https://github.com/Intery/StudyLion-streamalerts.git
|
||||
|
||||
@@ -1454,6 +1454,7 @@ CREATE TABLE shoutouts(
|
||||
CREATE TABLE counters(
|
||||
counterid SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
category TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE UNIQUE INDEX counters_name ON counters (name);
|
||||
@@ -1464,6 +1465,7 @@ CREATE TABLE counter_log(
|
||||
userid INTEGER NOT NULL,
|
||||
value INTEGER NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
details TEXT,
|
||||
context_str TEXT
|
||||
);
|
||||
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
||||
@@ -1484,6 +1486,81 @@ CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name
|
||||
|
||||
-- }}}
|
||||
|
||||
-- Voice Roles {{{
|
||||
CREATE TABLE voice_roles(
|
||||
voice_role_id SERIAL PRIMARY KEY,
|
||||
channelid BIGINT NOT NULL,
|
||||
roleid BIGINT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX voice_role_channels on voice_roles (channelid);
|
||||
|
||||
-- }}}
|
||||
|
||||
-- User and Community Profiles {{{
|
||||
DROP TABLE IF EXISTS community_members;
|
||||
DROP TABLE IF EXISTS communities_twitch;
|
||||
DROP TABLE IF EXISTS communities_discord;
|
||||
DROP TABLE IF EXISTS communities;
|
||||
DROP TABLE IF EXISTS profiles_twitch;
|
||||
DROP TABLE IF EXISTS profiles_discord;
|
||||
DROP TABLE IF EXISTS user_profiles;
|
||||
|
||||
|
||||
CREATE TABLE user_profiles(
|
||||
profileid SERIAL PRIMARY KEY,
|
||||
nickname TEXT,
|
||||
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE profiles_discord(
|
||||
linkid SERIAL PRIMARY KEY,
|
||||
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
userid BIGINT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid);
|
||||
CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid);
|
||||
|
||||
CREATE TABLE profiles_twitch(
|
||||
linkid SERIAL PRIMARY KEY,
|
||||
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
userid TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid);
|
||||
CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid);
|
||||
|
||||
|
||||
CREATE TABLE communities(
|
||||
communityid SERIAL PRIMARY KEY,
|
||||
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE communities_discord(
|
||||
guildid BIGINT PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX communities_discord_communityid ON communities_discord (communityid);
|
||||
|
||||
CREATE TABLE communities_twitch(
|
||||
channelid TEXT PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX communities_twitch_communityid ON communities_twitch (communityid);
|
||||
|
||||
CREATE TABLE community_members(
|
||||
memberid SERIAL PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
|
||||
-- }}}
|
||||
|
||||
-- Twitch User Auth {{{
|
||||
CREATE TABLE twitch_user_auth(
|
||||
@@ -1494,7 +1571,6 @@ CREATE TABLE twitch_user_auth(
|
||||
obtained_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
|
||||
CREATE TABLE twitch_user_scopes(
|
||||
userid TEXT REFERENCES twitch_user_auth (userid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
scope TEXT
|
||||
|
||||
2
skins
2
skins
Submodule skins updated: d3d6a28bc9...686857321e
@@ -98,6 +98,7 @@ async def main():
|
||||
config=conf,
|
||||
initial_extensions=[
|
||||
'utils', 'core', 'analytics',
|
||||
'twitch',
|
||||
'modules',
|
||||
'babel',
|
||||
'tracking.voice', 'tracking.text',
|
||||
|
||||
2
src/gui
2
src/gui
Submodule src/gui updated: c1bcb05c25...62d2484914
@@ -1,9 +1,12 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import logging
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
from twitchio.ext import pubsub
|
||||
from twitchio.ext.commands.core import itertools
|
||||
|
||||
from data import Database
|
||||
|
||||
@@ -23,5 +26,51 @@ class CrocBot(commands.Bot):
|
||||
self.data = data
|
||||
self.pubsub = pubsub.PubSubPool(self)
|
||||
|
||||
self._member_cache = defaultdict(dict)
|
||||
|
||||
async def event_ready(self):
|
||||
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
||||
|
||||
async def event_join(self, channel: twitchio.Channel, user: twitchio.User):
|
||||
self._member_cache[channel.name][user.name] = user
|
||||
|
||||
async def event_message(self, message: twitchio.Message):
|
||||
if message.channel and message.author:
|
||||
self._member_cache[message.channel.name][message.author.name] = message.author
|
||||
await self.handle_commands(message)
|
||||
|
||||
async def seek_user(self, userstr: str, matching=True, fuzzy=True):
|
||||
if userstr.startswith('@'):
|
||||
matching = False
|
||||
userstr = userstr.strip('@ ')
|
||||
|
||||
result = None
|
||||
if matching and len(userstr) >= 3:
|
||||
lowered = userstr.lower()
|
||||
full_matches = []
|
||||
for user in itertools.chain(*(cmems.values() for cmems in self._member_cache.values())):
|
||||
matchstr = user.name.lower()
|
||||
print(matchstr)
|
||||
if matchstr.startswith(lowered):
|
||||
result = user
|
||||
break
|
||||
if lowered in matchstr:
|
||||
full_matches.append(user)
|
||||
if result is None and full_matches:
|
||||
result = full_matches[0]
|
||||
print(result)
|
||||
|
||||
if result is None:
|
||||
lookup = userstr
|
||||
elif result.id is None:
|
||||
lookup = result.name
|
||||
else:
|
||||
lookup = None
|
||||
|
||||
if lookup:
|
||||
found = await self.fetch_users(names=[lookup])
|
||||
if found:
|
||||
result = found[0]
|
||||
|
||||
# No matches found
|
||||
return result
|
||||
|
||||
@@ -27,6 +27,7 @@ if TYPE_CHECKING:
|
||||
from meta.CrocBot import CrocBot
|
||||
from core.cog import CoreCog
|
||||
from core.config import ConfigCog
|
||||
from twitch.cog import TwitchAuthCog
|
||||
from tracking.voice.cog import VoiceTrackerCog
|
||||
from tracking.text.cog import TextTrackerCog
|
||||
from modules.config.cog import GuildConfigCog
|
||||
@@ -49,6 +50,7 @@ if TYPE_CHECKING:
|
||||
from modules.topgg.cog import TopggCog
|
||||
from modules.user_config.cog import UserConfigCog
|
||||
from modules.video_channels.cog import VideoCog
|
||||
from modules.profiles.cog import ProfileCog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -142,6 +144,10 @@ class LionBot(Bot):
|
||||
# To make the type checker happy about fetching cogs by name
|
||||
# TODO: Move this to stubs at some point
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: Literal['ProfileCog']) -> 'ProfileCog':
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
|
||||
...
|
||||
@@ -154,6 +160,10 @@ class LionBot(Bot):
|
||||
def get_cog(self, name: Literal['VoiceTrackerCog']) -> 'VoiceTrackerCog':
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: Literal['TwitchAuthCog']) -> 'TwitchAuthCog':
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_cog(self, name: Literal['TextTrackerCog']) -> 'TextTrackerCog':
|
||||
...
|
||||
|
||||
@@ -22,6 +22,7 @@ class LionCog(Cog):
|
||||
cls._placeholder_groups_ = set()
|
||||
cls._twitch_cmds_ = {}
|
||||
cls._twitch_events_ = {}
|
||||
cls._twitch_events_loaded_ = set()
|
||||
|
||||
for base in reversed(cls.__mro__):
|
||||
for elem, value in base.__dict__.items():
|
||||
@@ -47,6 +48,27 @@ class LionCog(Cog):
|
||||
|
||||
return await super()._inject(bot, *args, *kwargs)
|
||||
|
||||
def add_twitch_command(self, bot: Bot, command: Command):
|
||||
"""
|
||||
Dynamically register a command with the given bot.
|
||||
|
||||
The command will be deregistered on cog unload.
|
||||
"""
|
||||
# Remove any conflicting commands
|
||||
if cmd := bot.get_command(command.name):
|
||||
bot.remove_command(cmd.name)
|
||||
self._twitch_cmds_.pop(command.name, None)
|
||||
|
||||
try:
|
||||
self._twitch_cmds_[command.name] = command
|
||||
command._instance = self
|
||||
command.cog = self
|
||||
bot.add_command(command)
|
||||
except Exception:
|
||||
# Ensure the command doesn't die in the internal command cache
|
||||
self._twitch_cmds_.pop(command.name, None)
|
||||
raise
|
||||
|
||||
def _load_twitch_methods(self, bot: Bot):
|
||||
for name, command in self._twitch_cmds_.items():
|
||||
command._instance = self
|
||||
|
||||
@@ -2,6 +2,7 @@ this_package = 'modules'
|
||||
|
||||
active_discord = [
|
||||
'.sysadmin',
|
||||
'.profiles',
|
||||
'.config',
|
||||
'.user_config',
|
||||
'.skins',
|
||||
@@ -30,6 +31,11 @@ active_discord = [
|
||||
'.nowdoing',
|
||||
'.shoutouts',
|
||||
'.tagstrings',
|
||||
'.voiceroles',
|
||||
'.hyperfocus',
|
||||
'.twreminders',
|
||||
'.time',
|
||||
'.checkin',
|
||||
]
|
||||
|
||||
async def setup(bot):
|
||||
|
||||
8
src/modules/checkin/__init__.py
Normal file
8
src/modules/checkin/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import CheckinCog
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(CheckinCog(bot))
|
||||
154
src/modules/checkin/cog.py
Normal file
154
src/modules/checkin/cog.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
import datetime as dt
|
||||
from datetime import timedelta, datetime
|
||||
|
||||
import discord
|
||||
import twitchAPI
|
||||
from twitchAPI.object.eventsub import ChannelPointsCustomRewardRedemptionData
|
||||
from twitchAPI.eventsub.websocket import EventSubWebsocket
|
||||
from twitchAPI.type import AuthScope
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
|
||||
from meta import CrocBot, LionCog, LionContext, LionBot
|
||||
from utils.lib import utc_now
|
||||
from . import logger
|
||||
|
||||
|
||||
class CheckinCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.crocbot: CrocBot = bot.crocbot
|
||||
|
||||
self.listeners = []
|
||||
self.eswebsockets = {}
|
||||
|
||||
async def cog_load(self):
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._load_twitch_methods(self.crocbot)
|
||||
|
||||
check_in_channel_id = self.bot.config.croccy['check_in_channel'].strip()
|
||||
await self.attach_checkin_channel(check_in_channel_id)
|
||||
|
||||
async def cog_unload(self):
|
||||
self._unload_twitch_methods(self.crocbot)
|
||||
|
||||
async def fetch_eventsub_for(self, channelid):
|
||||
if (eventsub := self.eswebsockets.get(channelid)) is None:
|
||||
authcog = self.bot.get_cog('TwitchAuthCog')
|
||||
if not await authcog.check_auth(channelid, scopes=[AuthScope.CHANNEL_READ_REDEMPTIONS]):
|
||||
logger.error(
|
||||
f"Insufficient auth to login to registered check-in channelid {channelid}"
|
||||
)
|
||||
else:
|
||||
twitch = await authcog.fetch_client_for(channelid)
|
||||
eventsub = EventSubWebsocket(twitch)
|
||||
eventsub.start()
|
||||
self.eswebsockets[channelid] = eventsub
|
||||
return eventsub
|
||||
|
||||
async def attach_checkin_channel(self, channel):
|
||||
# Register a listener for the given channel (given as a string id)
|
||||
eventsub = await self.fetch_eventsub_for(channel)
|
||||
if eventsub:
|
||||
await eventsub.listen_channel_points_custom_reward_redemption_add(channel, self.handle_redeem)
|
||||
logger.info(f"Attached check-in listener to registered channel {channel}")
|
||||
else:
|
||||
logger.error(f"Could not attach checkin listener to registered channel {channel}")
|
||||
|
||||
async def handle_redeem(self, data: ChannelPointsCustomRewardRedemptionData):
|
||||
# Check if the redeem is one of the 'checkin' or 'quiet checkin' redeems.
|
||||
title = data.event.reward.title.lower()
|
||||
# TODO: Redeem ID based registration (configured)
|
||||
seeking = ('check in', 'quiet hello')
|
||||
if title in seeking:
|
||||
quiet = seeking.index(title)
|
||||
await self.do_checkin(
|
||||
data.event.broadcaster_user_id,
|
||||
data.event.broadcaster_user_login,
|
||||
data.event.user_id,
|
||||
data.event.user_name,
|
||||
quiet,
|
||||
data.event.redeemed_at
|
||||
)
|
||||
|
||||
async def do_checkin(self, channel, channel_name, user, user_name, quiet, redeemed_at):
|
||||
logger.info(
|
||||
f"Starting checkin process for {channel_name=}, {user_name=}, {quiet=}, {redeemed_at=}"
|
||||
)
|
||||
checkin_counter_name = '_checkin'
|
||||
first_counter_name = '_first'
|
||||
second_counter_name = '_second'
|
||||
third_counter_name = '_third'
|
||||
|
||||
counters = self.bot.get_cog('CounterCog')
|
||||
if not counters:
|
||||
raise ValueError("Check-in running without counters cog loaded!")
|
||||
profiles = self.bot.get_cog('ProfileCog')
|
||||
if not profiles:
|
||||
raise ValueError("Check-in running without profile cog loaded!")
|
||||
|
||||
# TODO: Relies on profile implementation detail
|
||||
profile = await profiles.fetch_profile_twitch(discord.Object(id=user))
|
||||
|
||||
stream_start = await self.get_stream_start(channel)
|
||||
# Stream has to be running for this to do anything
|
||||
if stream_start is not None:
|
||||
# Get all check-in redeems since the start of stream.
|
||||
check_in_counter = await counters.fetch_counter(checkin_counter_name)
|
||||
entries = await counters.data.CounterEntry.table.select_where(
|
||||
counters.data.CounterEntry.created_at >= stream_start,
|
||||
counterid=check_in_counter.counterid,
|
||||
)
|
||||
position = len(entries) + 1
|
||||
if profile.profileid not in (e['userid'] for e in entries):
|
||||
# User has not already checked in!
|
||||
# Check them in
|
||||
# TODO: May be worth setting custom counter time
|
||||
await counters.add_to_counter(
|
||||
counter=check_in_counter.name,
|
||||
userid=profile.profileid,
|
||||
value=1,
|
||||
)
|
||||
checkin_total = await counters.personal_total(checkin_counter_name, profile.profileid)
|
||||
|
||||
# If they deserve a first, give them that
|
||||
position_total = None
|
||||
if position <= 3:
|
||||
counter_name = (first_counter_name, second_counter_name, third_counter_name)[position-1]
|
||||
await counters.add_to_counter(
|
||||
counter=counter_name,
|
||||
userid=profile.profileid,
|
||||
value=1,
|
||||
)
|
||||
position_total = await counters.personal_total(counter_name, profile.profileid)
|
||||
|
||||
if not quiet:
|
||||
name = user_name
|
||||
if position == 1:
|
||||
message = f"Welcome in and congrats on first check-in {name}! You have been first {position_total}/{checkin_total} times!"
|
||||
else:
|
||||
# TODO: Randomised replies
|
||||
# TODO: Maybe different messages for lower positions or earlier times but not explicitly giving numbers?
|
||||
# Need to update this for stream calcs anyway.
|
||||
message = f"Welcome in {name}! You have checked in {checkin_total} times! Let's have a productive time together~"
|
||||
|
||||
# Now get the channel and post
|
||||
channel = self.crocbot.get_channel(channel_name)
|
||||
if not channel:
|
||||
logger.error(
|
||||
f"Channel {channel_name} is not in cache. Cannot send checkin reply."
|
||||
)
|
||||
else:
|
||||
await channel.send(message)
|
||||
|
||||
async def get_stream_start(self, channelid: str | int) -> Optional[datetime]:
|
||||
future = asyncio.run_coroutine_threadsafe(self._get_stream_start(channelid), self._loop)
|
||||
return future.result()
|
||||
|
||||
async def _get_stream_start(self, channelid: str | int) -> Optional[datetime]:
|
||||
streams = await self.crocbot.fetch_streams(user_ids=[int(channelid)])
|
||||
if streams:
|
||||
return streams[0].started_at
|
||||
@@ -3,17 +3,24 @@ from enum import Enum
|
||||
from typing import Optional
|
||||
from datetime import timedelta
|
||||
|
||||
from data.base import RawExpr
|
||||
from data.columns import Column
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
|
||||
|
||||
from data.queries import ORDER
|
||||
from meta import LionCog, LionBot, CrocBot
|
||||
from utils.lib import utc_now
|
||||
from meta import LionCog, LionBot, CrocBot, LionContext
|
||||
from modules.profiles.community import Community
|
||||
from modules.profiles.profile import UserProfile
|
||||
from utils.lib import utc_now, paginate_list, pager
|
||||
from . import logger
|
||||
from .data import CounterData
|
||||
from .graphics.weekly import counter_weekly_card, counter_monthly_card
|
||||
|
||||
|
||||
class PERIOD(Enum):
|
||||
@@ -25,6 +32,98 @@ class PERIOD(Enum):
|
||||
YEAR = ('this year', 'y', 'year', 'yearly')
|
||||
|
||||
|
||||
class ORIGIN(Enum):
|
||||
DISCORD = 'discord'
|
||||
TWITCH = 'twitch'
|
||||
|
||||
|
||||
def counter_cmd_factory(
|
||||
counter: str,
|
||||
response: str,
|
||||
default_period: Optional[PERIOD] = PERIOD.STREAM,
|
||||
context: Optional[str] = None
|
||||
):
|
||||
context = context or f"cmd: {counter}"
|
||||
async def counter_cmd(
|
||||
cog,
|
||||
ctx: commands.Context | LionContext,
|
||||
origin: ORIGIN,
|
||||
author: UserProfile,
|
||||
community: Community,
|
||||
args: Optional[str]
|
||||
):
|
||||
userid = author.profileid
|
||||
period, start_time = await cog.parse_period(community, '', default=default_period)
|
||||
|
||||
args = (args or '').strip(" ")
|
||||
splits = args.split(maxsplit=1)
|
||||
splits = [split.strip() for split in splits if split]
|
||||
|
||||
details = None
|
||||
amount = 1
|
||||
|
||||
if splits:
|
||||
if splits[0].isdigit() or (splits[0].startswith('-') and splits[0][1:].isdigit()):
|
||||
amount = int(splits[0])
|
||||
splits = splits[1:]
|
||||
if splits:
|
||||
details = ' '.join(splits)
|
||||
|
||||
await cog.add_to_counter(
|
||||
counter, userid, amount,
|
||||
context=context,
|
||||
details=details
|
||||
)
|
||||
lb = await cog.leaderboard(counter, start_time=start_time)
|
||||
user_total = lb.get(userid, 0)
|
||||
total = sum(lb.values())
|
||||
await ctx.reply(
|
||||
response.format(
|
||||
total=total,
|
||||
period=period,
|
||||
period_name=period.value[0],
|
||||
detailsorname=details or counter,
|
||||
user_total=user_total,
|
||||
)
|
||||
)
|
||||
|
||||
async def lb_cmd(
|
||||
cog,
|
||||
ctx: commands.Context | LionContext,
|
||||
origin: ORIGIN,
|
||||
author: UserProfile,
|
||||
community: Community,
|
||||
args: Optional[str]
|
||||
):
|
||||
await cog.show_lb(ctx, counter, args, author, community, origin)
|
||||
|
||||
async def undo_cmd(
|
||||
cog,
|
||||
ctx: commands.Context | LionContext,
|
||||
origin: ORIGIN,
|
||||
author: UserProfile,
|
||||
community: Community,
|
||||
args: Optional[str]
|
||||
):
|
||||
userid = author.profileid
|
||||
_counter = await cog.fetch_counter(counter)
|
||||
query = cog.data.CounterEntry.fetch_where(
|
||||
counterid=_counter.counterid,
|
||||
userid=userid,
|
||||
)
|
||||
query.order_by('created_at', direction=ORDER.DESC)
|
||||
query.limit(1)
|
||||
results = await query
|
||||
if not results:
|
||||
await ctx.reply("Nothing to delete!")
|
||||
else:
|
||||
row = results[0]
|
||||
await row.delete()
|
||||
await ctx.reply("Undo successful!")
|
||||
|
||||
return (counter_cmd, lb_cmd, undo_cmd)
|
||||
|
||||
|
||||
class CounterCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
@@ -40,12 +139,73 @@ class CounterCog(LionCog):
|
||||
self._load_twitch_methods(self.crocbot)
|
||||
|
||||
await self.data.init()
|
||||
|
||||
await self.load_counter_commands()
|
||||
await self.load_counters()
|
||||
self.loaded.set()
|
||||
|
||||
profiles = self.bot.get_cog('ProfileCog')
|
||||
profiles.add_profile_migrator(self.migrate_profiles, name='counters')
|
||||
|
||||
async def cog_unload(self):
|
||||
self._unload_twitch_methods(self.crocbot)
|
||||
|
||||
async def load_counter_commands(self):
|
||||
rows = await self.data.CounterCommand.fetch_where()
|
||||
for row in rows:
|
||||
counter = await self.data.Counter.fetch(row.counterid)
|
||||
counter_cb, lb_cb, undo_cb = counter_cmd_factory(
|
||||
counter.name,
|
||||
row.response
|
||||
)
|
||||
twitch_cmds = []
|
||||
disc_cmds = []
|
||||
twitch_cmds.append(
|
||||
commands.command(
|
||||
name=row.name
|
||||
)(self.twitch_callback(counter_cb))
|
||||
)
|
||||
disc_cmds.append(
|
||||
cmds.hybrid_command(
|
||||
name=row.name,
|
||||
with_app_command=False,
|
||||
)(self.discord_callback(counter_cb))
|
||||
)
|
||||
|
||||
if row.lbname:
|
||||
twitch_cmds.append(
|
||||
commands.command(
|
||||
name=row.lbname
|
||||
)(self.twitch_callback(lb_cb))
|
||||
)
|
||||
disc_cmds.append(
|
||||
cmds.hybrid_command(
|
||||
name=row.lbname,
|
||||
with_app_command=False,
|
||||
)(self.discord_callback(lb_cb))
|
||||
)
|
||||
if row.undoname:
|
||||
twitch_cmds.append(
|
||||
commands.command(
|
||||
name=row.undoname,
|
||||
)(self.twitch_callback(undo_cb))
|
||||
)
|
||||
disc_cmds.append(
|
||||
cmds.hybrid_command(
|
||||
name=row.undoname,
|
||||
with_app_command=False,
|
||||
)(self.discord_callback(undo_cb))
|
||||
)
|
||||
|
||||
for cmd in twitch_cmds:
|
||||
self.add_twitch_command(self.crocbot, cmd)
|
||||
for cmd in disc_cmds:
|
||||
# cmd.cog = self
|
||||
self.bot.add_command(cmd)
|
||||
print(f"Adding command: {cmd}")
|
||||
|
||||
logger.info(f"(Re)Loaded {len(rows)} counter commands!")
|
||||
|
||||
async def cog_check(self, ctx):
|
||||
return True
|
||||
|
||||
@@ -59,6 +219,87 @@ class CounterCog(LionCog):
|
||||
f"Loaded {len(self.counters)} counters."
|
||||
)
|
||||
|
||||
async def migrate_profiles(self, source_profile: UserProfile, target_profile: UserProfile):
|
||||
"""
|
||||
Move source profile entries to target profile entries
|
||||
"""
|
||||
results = ["(Counters)"]
|
||||
|
||||
rows = await self.data.CounterEntry.table.update_where(userid=source_profile.profileid).set(userid=target_profile.profileid)
|
||||
if rows:
|
||||
results.append(
|
||||
f"Migrated {len(rows)} counter entries from source profile."
|
||||
)
|
||||
else:
|
||||
results.append(
|
||||
"No counter entries to migrate in source profile."
|
||||
)
|
||||
|
||||
return ' '.join(results)
|
||||
|
||||
async def user_profile_migration(self):
|
||||
"""
|
||||
Manual single-use migration method from the old userid format to the new profileid format.
|
||||
"""
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
entries = await self.data.CounterEntry.fetch_where()
|
||||
for entry in entries:
|
||||
if entry.userid > 1000:
|
||||
# Assume userid is a twitch userid
|
||||
profile = await UserProfile.fetch_from_twitchid(self.bot, entry.userid)
|
||||
if not profile:
|
||||
# Need to create
|
||||
users = await self.crocbot.fetch_users(ids=[entry.userid])
|
||||
if not users:
|
||||
continue
|
||||
user = users[0]
|
||||
profile = await UserProfile.create_from_twitch(self.bot, user)
|
||||
await entry.update(userid=profile.profileid)
|
||||
logger.info("Completed single-shot user profile migration")
|
||||
|
||||
# General API
|
||||
def twitch_callback(self, callback):
|
||||
"""
|
||||
Generate a Twitch command callback from the given general callback.
|
||||
|
||||
General callback must be of the form
|
||||
callback(cog, ctx: GeneralContext, origin: ORIGIN, author: Profile, comm: Community, args: Optional[str])
|
||||
|
||||
Return will be a command callback of the form
|
||||
callback(cog, ctx: Context, *, args: Optional[str] = None)
|
||||
"""
|
||||
async def command_callback(cog: CounterCog, ctx: commands.Context, *, args: Optional[str] = None):
|
||||
profiles = cog.bot.get_cog('ProfileCog')
|
||||
# Compute author profile
|
||||
author = await profiles.fetch_profile_twitch(ctx.author)
|
||||
# Compute community profile
|
||||
community = await profiles.fetch_community_twitch(await ctx.channel.user())
|
||||
return await callback(cog, ctx, ORIGIN.TWITCH, author, community, args)
|
||||
return command_callback
|
||||
|
||||
def discord_callback(self, callback):
|
||||
"""
|
||||
Generate a Discord command callback from the given general callback.
|
||||
|
||||
General callback must be of the form
|
||||
callback(cog, ctx: GeneralContext, origin: ORIGIN, author: Profile, comm: Community, args: Optional[str])
|
||||
|
||||
Return will be a command callback of the form
|
||||
callback(cog, ctx: LionContext, *, args: Optional[str] = None)
|
||||
"""
|
||||
cog = self
|
||||
async def command_callback(ctx: LionContext, *, args: Optional[str] = None):
|
||||
profiles = cog.bot.get_cog('ProfileCog')
|
||||
# Compute author profile
|
||||
author = await profiles.fetch_profile_discord(ctx.author)
|
||||
# Compute community profile
|
||||
community = await profiles.fetch_community_discord(ctx.guild)
|
||||
return await callback(cog, ctx, ORIGIN.DISCORD, author, community, args)
|
||||
|
||||
return command_callback
|
||||
|
||||
# Counters API
|
||||
|
||||
async def fetch_counter(self, counter: str) -> CounterData.Counter:
|
||||
@@ -80,13 +321,19 @@ class CounterCog(LionCog):
|
||||
if row:
|
||||
await self.data.CounterEntry.table.delete_where(counterid=row.counterid)
|
||||
|
||||
async def add_to_counter(self, counter: str, userid: int, value: int, context: Optional[str]=None):
|
||||
async def add_to_counter(
|
||||
self,
|
||||
counter: str, userid: int, value: int,
|
||||
context: Optional[str]=None,
|
||||
details: Optional[str]=None,
|
||||
):
|
||||
row = await self.fetch_counter(counter)
|
||||
return await self.data.CounterEntry.create(
|
||||
counterid=row.counterid,
|
||||
userid=userid,
|
||||
value=value,
|
||||
context_str=context
|
||||
context_str=context,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def leaderboard(self, counter: str, start_time=None):
|
||||
@@ -119,13 +366,70 @@ class CounterCog(LionCog):
|
||||
results = await query
|
||||
return results[0]['counter_total'] if results else 0
|
||||
|
||||
# Manage commands
|
||||
@commands.command()
|
||||
async def countermigration(self, ctx: commands.Context, *, args: Optional[str]=None):
|
||||
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||
return
|
||||
await self.user_profile_migration()
|
||||
await ctx.reply("Counter userid->profileid migration done.")
|
||||
|
||||
# Counters commands
|
||||
@commands.command()
|
||||
async def counterslb(self, ctx: commands.Context, *, periodstr: Optional[str] = None):
|
||||
"""
|
||||
Build a leaderboard of counter totals in the given period.
|
||||
"""
|
||||
profiles = self.bot.get_cog('ProfileCog')
|
||||
author = await profiles.fetch_profile_twitch(ctx.author)
|
||||
userid = author.profileid
|
||||
community = await profiles.fetch_community_twitch(await ctx.channel.user())
|
||||
|
||||
period, start_time = await self.parse_period(community, periodstr or '')
|
||||
|
||||
query = self.data.CounterEntry.table.select_where()
|
||||
query.group_by('counterid')
|
||||
query.select('counterid', counter_total='SUM(value)')
|
||||
query.order_by('counter_total', ORDER.DESC)
|
||||
# query.where(Column('counter_total') > 0)
|
||||
if start_time is not None:
|
||||
query.where(self.data.CounterEntry.created_at >= start_time)
|
||||
query.with_no_adapter()
|
||||
results = await query
|
||||
query.where(self.data.CounterEntry.userid == userid)
|
||||
user_results = await query
|
||||
|
||||
lb = {result['counterid']: result['counter_total'] for result in results}
|
||||
userlb = {result['counterid']: result['counter_total'] for result in user_results}
|
||||
|
||||
counters = await self.data.Counter.fetch_where(counterid=list(lb.keys()))
|
||||
cmap = {c.counterid: c for c in counters}
|
||||
|
||||
parts = []
|
||||
for cid, ctotal in lb.items():
|
||||
if not ctotal:
|
||||
continue
|
||||
counter = cmap[cid]
|
||||
user_total = userlb.get(cid) or 0
|
||||
|
||||
parts.append(f"{counter.name}: {ctotal}")
|
||||
|
||||
prefix = 'top 10 ' if len(parts) > 10 else ''
|
||||
parts = parts[:10]
|
||||
|
||||
lbstr = '; '.join(parts)
|
||||
await ctx.reply(f"Counters {period.value[-1]} {prefix}leaderboard -- {lbstr}")
|
||||
|
||||
@commands.command()
|
||||
async def counter(self, ctx: commands.Context, name: str, subcmd: Optional[str], *, args: Optional[str]=None):
|
||||
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||
return
|
||||
|
||||
name = name.lower()
|
||||
profiles = self.bot.get_cog('ProfileCog')
|
||||
author = await profiles.fetch_profile_twitch(ctx.author)
|
||||
userid = author.profileid
|
||||
community = await profiles.fetch_community_twitch(await ctx.channel.user())
|
||||
|
||||
if subcmd is None or subcmd == 'show':
|
||||
# Show
|
||||
@@ -142,23 +446,56 @@ class CounterCog(LionCog):
|
||||
return
|
||||
await self.add_to_counter(
|
||||
name,
|
||||
int(ctx.author.id),
|
||||
userid,
|
||||
value,
|
||||
context='cmd: counter add'
|
||||
)
|
||||
total = await self.totals(name)
|
||||
await ctx.reply(f"'{name}' counter is now: {total}")
|
||||
elif subcmd == 'lb':
|
||||
user = await ctx.channel.user()
|
||||
lbstr = await self.formatted_lb(name, args or '', int(user.id))
|
||||
await ctx.reply(lbstr)
|
||||
await self.show_lb(ctx, name, args or '', author, community, origin=ORIGIN.TWITCH)
|
||||
elif subcmd == 'clear':
|
||||
await self.reset_counter(name)
|
||||
await ctx.reply(f"'{name}' counter reset.")
|
||||
else:
|
||||
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear'.")
|
||||
elif subcmd == 'alias':
|
||||
splits = args.split(maxsplit=3) if args else []
|
||||
counter = await self.fetch_counter(name)
|
||||
rows = await self.data.CounterCommand.fetch_where(counterid=counter.counterid)
|
||||
existing = rows[0] if rows else None
|
||||
if existing and not args:
|
||||
# Show current alias
|
||||
await ctx.reply(
|
||||
f"Counter '{name}' aliases: '!{existing.name}' to add to counter; "
|
||||
f"'!{existing.lbname}' to view counter leaderboard; "
|
||||
f"'!{existing.undoname}' to undo (your) last addition."
|
||||
)
|
||||
elif len(splits) < 4:
|
||||
# Show usage
|
||||
await ctx.reply(
|
||||
"USAGE: !counter <name> alias <cmdname> <lbname> <undoname> <response> -- "
|
||||
"Response accepts keywords {total}, {period}, {period_name}, {detailsorname}, {user_total}."
|
||||
)
|
||||
else:
|
||||
# Create new alias
|
||||
cmdname, lbname, undoname, response = splits
|
||||
# Remove any existing alias
|
||||
await self.data.CounterCommand.table.delete_where(name=cmdname)
|
||||
|
||||
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
|
||||
alias = await self.data.CounterCommand.create(
|
||||
name=cmdname,
|
||||
counterid=counter.counterid,
|
||||
lbname=lbname, undoname=undoname, response=response
|
||||
)
|
||||
await self.load_counter_commands()
|
||||
await ctx.reply(
|
||||
f"Alias created for counter '{name}': '!{alias.name}' to add to counter; "
|
||||
f"'!{alias.lbname}' to view counter leaderboard; "
|
||||
f"'!{alias.undoname}' to undo (your) last addition."
|
||||
)
|
||||
else:
|
||||
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear', 'alias'.")
|
||||
|
||||
async def parse_period(self, community: Community, periodstr: str, default=PERIOD.STREAM):
|
||||
if periodstr:
|
||||
period = next((period for period in PERIOD if periodstr.lower() in period.value), None)
|
||||
if period is None:
|
||||
@@ -172,9 +509,13 @@ class CounterCog(LionCog):
|
||||
if period is PERIOD.ALL:
|
||||
start_time = None
|
||||
elif period is PERIOD.STREAM:
|
||||
streams = await self.crocbot.fetch_streams(user_ids=[userid])
|
||||
if streams:
|
||||
stream = streams[0]
|
||||
twitches = await community.twitch_channels()
|
||||
stream = None
|
||||
if twitches:
|
||||
twitch = twitches[0]
|
||||
streams = await self.crocbot.fetch_streams(user_ids=[int(twitch.channelid)])
|
||||
stream = streams[0] if streams else None
|
||||
if stream:
|
||||
start_time = stream.started_at
|
||||
else:
|
||||
period = PERIOD.ALL
|
||||
@@ -193,100 +534,104 @@ class CounterCog(LionCog):
|
||||
|
||||
return (period, start_time)
|
||||
|
||||
async def formatted_lb(self, counter: str, periodstr: str, channelid: int):
|
||||
@cmds.hybrid_command(
|
||||
name='counterlb',
|
||||
description="Show the leaderboard for the given counter."
|
||||
)
|
||||
async def counterlb_dcmd(self, ctx: LionContext, counter: str, period: Optional[str] = None):
|
||||
profiles = self.bot.get_cog('ProfileCog')
|
||||
author = await profiles.fetch_profile_discord(ctx.author)
|
||||
community = await profiles.fetch_community_discord(ctx.guild)
|
||||
await self.show_lb(ctx, counter, period, author, community, ORIGIN.DISCORD)
|
||||
|
||||
period, start_time = await self.parse_period(channelid, periodstr)
|
||||
@cmds.hybrid_command(
|
||||
name='counterstats',
|
||||
description="Show your stats for the given counter."
|
||||
)
|
||||
async def counterstats_dcmd(self, ctx: LionContext, counter: str, period: Optional[str]=None):
|
||||
profiles = self.bot.get_cog('ProfileCog')
|
||||
author = await profiles.fetch_profile_discord(ctx.author)
|
||||
community = await profiles.fetch_community_discord(ctx.guild)
|
||||
|
||||
if period and period.lower() in ('monthly', 'month'):
|
||||
card = await counter_monthly_card(
|
||||
self.bot,
|
||||
userid=ctx.author.id,
|
||||
profile=author,
|
||||
counter=await self.fetch_counter(counter),
|
||||
guildid=ctx.guild.id,
|
||||
offset=0,
|
||||
)
|
||||
await card.render()
|
||||
await ctx.reply(file=card.as_file('stats.png'))
|
||||
else:
|
||||
card = await counter_weekly_card(
|
||||
self.bot,
|
||||
userid=ctx.author.id,
|
||||
profile=author,
|
||||
counter=await self.fetch_counter(counter),
|
||||
guildid=ctx.guild.id,
|
||||
offset=0,
|
||||
)
|
||||
await card.render()
|
||||
await ctx.reply(file=card.as_file('stats.png'))
|
||||
|
||||
async def show_lb(
|
||||
self,
|
||||
ctx: commands.Context | LionContext,
|
||||
counter: str,
|
||||
periodstr: str,
|
||||
caller: UserProfile,
|
||||
community: Community,
|
||||
origin: ORIGIN = ORIGIN.TWITCH
|
||||
):
|
||||
|
||||
period, start_time = await self.parse_period(community, periodstr)
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
if lb:
|
||||
userids = list(lb.keys())
|
||||
users = await self.crocbot.fetch_users(ids=userids)
|
||||
name_map = {user.id: user.display_name for user in users}
|
||||
name_map = {}
|
||||
for userid in lb.keys():
|
||||
profile = await UserProfile.fetch(self.bot, userid)
|
||||
name = await profile.get_name()
|
||||
name_map[userid] = name
|
||||
|
||||
if not lb:
|
||||
await ctx.reply(
|
||||
f"{counter} {period.value[-1]} leaderboard is empty!"
|
||||
)
|
||||
elif origin is ORIGIN.TWITCH:
|
||||
parts = []
|
||||
for userid, total in lb.items():
|
||||
items = list(lb.items())
|
||||
prefix = 'top 10 ' if len(items) > 10 else ''
|
||||
items = items[:10]
|
||||
for userid, total in items:
|
||||
name = name_map.get(userid, str(userid))
|
||||
part = f"{name}: {total}"
|
||||
parts.append(part)
|
||||
lbstr = '; '.join(parts)
|
||||
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}"
|
||||
else:
|
||||
return f"{counter} {period.value[-1]} leaderboard is empty!"
|
||||
await ctx.reply(f"{counter} {period.value[-1]} {prefix}leaderboard --- {lbstr}")
|
||||
elif origin is ORIGIN.DISCORD:
|
||||
title = f"'{counter}' {period.value[-1]} leaderboard"
|
||||
|
||||
# Misc actual counter commands
|
||||
# TODO: Factor this out to a different module...
|
||||
@commands.command()
|
||||
async def tea(self, ctx: commands.Context, *, args: Optional[str]=None):
|
||||
userid = int(ctx.author.id)
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
period, start_time = await self.parse_period(channelid, '')
|
||||
counter = 'tea'
|
||||
lb_strings = []
|
||||
author_index = None
|
||||
max_name_len = min((30, max(len(name) for name in name_map.values())))
|
||||
for i, (uid, total) in enumerate(lb.items()):
|
||||
if author_index is None and uid == caller.profileid:
|
||||
author_index = i
|
||||
lb_strings.append(
|
||||
"{:<{}}\t{:<9}".format(
|
||||
name_map[uid],
|
||||
max_name_len,
|
||||
total,
|
||||
)
|
||||
)
|
||||
|
||||
await self.add_to_counter(
|
||||
counter,
|
||||
userid,
|
||||
1,
|
||||
context='cmd: tea'
|
||||
)
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
user_total = lb.get(userid, 0)
|
||||
total = sum(lb.values())
|
||||
await ctx.reply(f"Enjoy your tea! We have had {total} cups of tea {period.value[0]}.")
|
||||
|
||||
@commands.command()
|
||||
async def tealb(self, ctx: commands.Context, *, args: str = ''):
|
||||
user = await ctx.channel.user()
|
||||
await ctx.reply(await self.formatted_lb('tea', args, int(user.id)))
|
||||
page_len = 20
|
||||
pages = paginate_list(lb_strings, block_length=page_len, title=title)
|
||||
start_page = author_index // page_len if author_index is not None else 0
|
||||
|
||||
@commands.command()
|
||||
async def coffee(self, ctx: commands.Context, *, args: Optional[str]=None):
|
||||
userid = int(ctx.author.id)
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
period, start_time = await self.parse_period(channelid, '')
|
||||
counter = 'coffee'
|
||||
|
||||
await self.add_to_counter(
|
||||
counter,
|
||||
userid,
|
||||
1,
|
||||
context='cmd: coffee'
|
||||
)
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
user_total = lb.get(userid, 0)
|
||||
total = sum(lb.values())
|
||||
await ctx.reply(f"Enjoy your coffee! We have had {total} cups of coffee {period.value[0]}.")
|
||||
|
||||
@commands.command()
|
||||
async def coffeelb(self, ctx: commands.Context, *, args: str = ''):
|
||||
user = await ctx.channel.user()
|
||||
await ctx.reply(await self.formatted_lb('coffee', args, int(user.id)))
|
||||
|
||||
@commands.command()
|
||||
async def water(self, ctx: commands.Context, *, args: Optional[str]=None):
|
||||
userid = int(ctx.author.id)
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
period, start_time = await self.parse_period(channelid, '')
|
||||
counter = 'water'
|
||||
|
||||
await self.add_to_counter(
|
||||
counter,
|
||||
userid,
|
||||
1,
|
||||
context='cmd: water'
|
||||
)
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
user_total = lb.get(userid, 0)
|
||||
total = sum(lb.values())
|
||||
await ctx.reply(f"Good job hydrating! We have had {total} cups of water {period.value[0]}.")
|
||||
|
||||
@commands.command()
|
||||
async def waterlb(self, ctx: commands.Context, *, args: str = ''):
|
||||
user = await ctx.channel.user()
|
||||
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
|
||||
|
||||
@commands.command()
|
||||
async def stuff(self, ctx: commands.Context, *, args: str = ''):
|
||||
await ctx.reply(f"Stuff {args}")
|
||||
|
||||
@cmds.hybrid_command('water')
|
||||
async def d_water_cmd(self, ctx):
|
||||
await ctx.reply(repr(ctx))
|
||||
await pager(
|
||||
ctx,
|
||||
pages,
|
||||
start_at=start_page
|
||||
)
|
||||
|
||||
@@ -10,7 +10,8 @@ class CounterData(Registry):
|
||||
CREATE TABLE counters(
|
||||
counterid SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
category TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX counters_name ON counters (name);
|
||||
"""
|
||||
@@ -19,6 +20,7 @@ class CounterData(Registry):
|
||||
|
||||
counterid = Integer(primary=True)
|
||||
name = String()
|
||||
category = String()
|
||||
created_at = Timestamp()
|
||||
|
||||
class CounterEntry(RowModel):
|
||||
@@ -31,7 +33,8 @@ class CounterData(Registry):
|
||||
userid INTEGER NOT NULL,
|
||||
value INTEGER NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
context_str TEXT
|
||||
context_str TEXT,
|
||||
details TEXT
|
||||
);
|
||||
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
||||
"""
|
||||
@@ -44,5 +47,28 @@ class CounterData(Registry):
|
||||
value = Integer()
|
||||
created_at = Timestamp()
|
||||
context_str = String()
|
||||
details = String()
|
||||
|
||||
class CounterCommand(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE counter_commands(
|
||||
name TEXT PRIMARY KEY,
|
||||
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
lbname TEXT,
|
||||
undoname TEXT,
|
||||
response TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
# NOTE: This table will be replaced by aliases soon anyway
|
||||
# So no need to worry about integrity or future-proofing
|
||||
_tablename_ = 'counter_commands'
|
||||
_cache_ = {}
|
||||
|
||||
name = String(primary=True)
|
||||
counterid = Integer()
|
||||
lbname = String()
|
||||
undoname = String()
|
||||
response = String()
|
||||
|
||||
|
||||
0
src/modules/counters/graphics/monthly.py
Normal file
0
src/modules/counters/graphics/monthly.py
Normal file
222
src/modules/counters/graphics/weekly.py
Normal file
222
src/modules/counters/graphics/weekly.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import itertools
|
||||
from typing import Optional
|
||||
from datetime import timedelta, datetime
|
||||
import calendar
|
||||
|
||||
from meta import LionBot
|
||||
from gui.cards import WeeklyStatsCard, MonthlyStatsCard
|
||||
from gui.base import CardMode
|
||||
from modules.profiles.profile import UserProfile
|
||||
from babel import LocalBabel
|
||||
from modules.statistics.lib import apply_month_offset
|
||||
|
||||
from ..data import CounterData
|
||||
|
||||
babel = LocalBabel('counters')
|
||||
_ = babel._
|
||||
|
||||
|
||||
|
||||
async def counter_monthly_card(
|
||||
bot: LionBot,
|
||||
userid: int,
|
||||
profile: UserProfile,
|
||||
counter: CounterData.Counter,
|
||||
guildid: int,
|
||||
offset: int,
|
||||
):
|
||||
cog = bot.get_cog('CounterCog')
|
||||
data: CounterData = cog.data
|
||||
|
||||
if guildid:
|
||||
lion = await bot.core.lions.fetch_member(guildid, userid)
|
||||
user = await lion.fetch_member()
|
||||
else:
|
||||
lion = await bot.core.lions.fetch_user(userid)
|
||||
user = await bot.fetch_user(userid)
|
||||
today = lion.today
|
||||
|
||||
month_start = today.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
target = apply_month_offset(month_start, offset)
|
||||
target_end = (target + timedelta(days=40)).replace(day=1, hour=0, minute=0) - timedelta(days=1)
|
||||
|
||||
months = [target]
|
||||
for i in range(0, 3):
|
||||
months.append((months[-1] - timedelta(days=1)).replace(day=1))
|
||||
months.reverse()
|
||||
|
||||
rows = await data.CounterEntry.fetch_where(
|
||||
data.CounterEntry.counterid == counter.counterid,
|
||||
data.CounterEntry.userid == profile.profileid,
|
||||
data.CounterEntry.created_at <= target_end,
|
||||
data.CounterEntry.created_at >= months[0],
|
||||
)
|
||||
|
||||
events = [(row.created_at, row.value) for row in rows]
|
||||
|
||||
month_lengths = [
|
||||
(calendar.monthrange(month.year, month.month)[1]) for month in months
|
||||
]
|
||||
month_dates = []
|
||||
for month, length in zip(months, month_lengths):
|
||||
for day in range(1, length + 1):
|
||||
month_dates.append(datetime(month.year, month.month, day, tzinfo=month.tzinfo))
|
||||
|
||||
monthly_flat = events_to_dayfreq(events, month_dates)
|
||||
print(monthly_flat)
|
||||
|
||||
monthly = []
|
||||
i = 0
|
||||
for length in month_lengths:
|
||||
this_month = monthly_flat[i : i+length]
|
||||
i += length
|
||||
monthly.append(this_month)
|
||||
|
||||
|
||||
skin = await bot.get_cog('CustomSkinCog').get_skinargs_for(
|
||||
guildid, userid, MonthlyStatsCard.card_id
|
||||
)
|
||||
skin |= {
|
||||
'title_text': f"{counter.name.upper()}",
|
||||
'this_month_text': f"THIS MONTH: {{amount}} {counter.name.upper()}",
|
||||
'last_month_text': f"LAST MONTH: {{amount}} {counter.name.upper()}"
|
||||
}
|
||||
|
||||
if user:
|
||||
username = (user.display_name, '')
|
||||
else:
|
||||
username = (await profile.get_name(), '')
|
||||
|
||||
|
||||
card = MonthlyStatsCard(
|
||||
user=username,
|
||||
timezone=str(lion.timezone),
|
||||
now=lion.now.timestamp(),
|
||||
month=int(target.timestamp()),
|
||||
monthly=monthly,
|
||||
current_streak=-1,
|
||||
longest_streak=-1,
|
||||
skin=skin | {'mode': CardMode.TEXT}
|
||||
)
|
||||
return card
|
||||
|
||||
|
||||
|
||||
|
||||
async def counter_weekly_card(
|
||||
bot: LionBot,
|
||||
userid: int,
|
||||
profile: UserProfile,
|
||||
counter: CounterData.Counter,
|
||||
guildid: int,
|
||||
offset: int,
|
||||
):
|
||||
cog = bot.get_cog('CounterCog')
|
||||
data: CounterData = cog.data
|
||||
|
||||
if guildid:
|
||||
lion = await bot.core.lions.fetch_member(guildid, userid)
|
||||
user = await lion.fetch_member()
|
||||
else:
|
||||
lion = await bot.core.lions.fetch_user(userid)
|
||||
user = await bot.fetch_user(userid)
|
||||
today = lion.today
|
||||
week_start = today - timedelta(days=today.weekday()) - timedelta(weeks=offset)
|
||||
days = [week_start + timedelta(i) for i in range(-7, 8 if offset else (today.weekday() + 2))]
|
||||
|
||||
rows = await data.CounterEntry.fetch_where(
|
||||
data.CounterEntry.counterid == counter.counterid,
|
||||
data.CounterEntry.userid == profile.profileid,
|
||||
data.CounterEntry.created_at <= days[-1],
|
||||
data.CounterEntry.created_at >= days[0],
|
||||
)
|
||||
|
||||
events = [(row.created_at, row.value) for row in rows]
|
||||
|
||||
daily = events_to_dayfreq(events, days)
|
||||
sessions = events_to_sessions(next(zip(*events), []))
|
||||
|
||||
skin = await bot.get_cog('CustomSkinCog').get_skinargs_for(
|
||||
guildid, userid, WeeklyStatsCard.card_id
|
||||
)
|
||||
skin |= {
|
||||
'title_text': f"{counter.name.upper()}",
|
||||
'this_week_text': f"THIS WEEK: {{amount}} {counter.name.upper()}",
|
||||
'last_week_text': f"LAST WEEK: {{amount}} {counter.name.upper()}"
|
||||
}
|
||||
|
||||
if user:
|
||||
username = (user.display_name, '')
|
||||
else:
|
||||
username = (await profile.get_name(), '')
|
||||
|
||||
|
||||
card = WeeklyStatsCard(
|
||||
user=username,
|
||||
timezone=str(lion.timezone),
|
||||
now=lion.now.timestamp(),
|
||||
week=week_start.timestamp(),
|
||||
daily=tuple(map(int, daily)),
|
||||
sessions=sessions,
|
||||
skin=skin | {'mode': CardMode.TEXT}
|
||||
)
|
||||
return card
|
||||
|
||||
|
||||
|
||||
def events_to_dayfreq(events: list[tuple[datetime, int]], days: list[datetime]) -> list[int]:
|
||||
if not days:
|
||||
return []
|
||||
|
||||
last_day = 0
|
||||
dayts = 0
|
||||
|
||||
daymap = {}
|
||||
for day in sorted(days, reverse=True):
|
||||
dayts = day.timestamp()
|
||||
last_day = last_day or (day + timedelta(days=1)).timestamp()
|
||||
daymap[dayts] = 0
|
||||
|
||||
first_day = dayts
|
||||
|
||||
for tim, count in events:
|
||||
timts = tim.timestamp()
|
||||
if not first_day < timts < last_day:
|
||||
continue
|
||||
|
||||
for day_start in daymap:
|
||||
if timts > day_start:
|
||||
daymap[day_start] += count
|
||||
break
|
||||
|
||||
return list(reversed(daymap.values()))
|
||||
|
||||
|
||||
def events_to_sessions(event_times: list[datetime]) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Convert a provided list of event times to a session list.
|
||||
"""
|
||||
sessions = []
|
||||
|
||||
session_start = None
|
||||
session_end = None
|
||||
|
||||
SESSION_GAP = 60 * 30
|
||||
SESSION_RADIUS = 60 * 30
|
||||
|
||||
for time in sorted(event_times):
|
||||
if session_start and session_end and (time - session_end).total_seconds() - SESSION_RADIUS > SESSION_GAP:
|
||||
session = (int(session_start.timestamp()), int(session_end.timestamp()))
|
||||
sessions.append(session)
|
||||
session_start = None
|
||||
session_end = None
|
||||
|
||||
if session_start is None:
|
||||
session_start = time - timedelta(seconds=SESSION_RADIUS)
|
||||
session_end = time + timedelta(seconds=SESSION_RADIUS)
|
||||
|
||||
if session_start and session_end:
|
||||
session = (int(session_start.timestamp()), int(session_end.timestamp()))
|
||||
sessions.append(session)
|
||||
|
||||
return sessions
|
||||
9
src/modules/counters/migration.sql
Normal file
9
src/modules/counters/migration.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
ALTER TABLE counters ADD COLUMN category TEXT;
|
||||
ALTER TABLE counter_log ADD COLUMN details TEXT;
|
||||
CREATE TABLE counter_commands(
|
||||
name TEXT PRIMARY KEY,
|
||||
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
lbname TEXT,
|
||||
undoname TEXT,
|
||||
response TEXT NOT NULL
|
||||
);
|
||||
0
src/modules/counters/ui/leaderboard.py
Normal file
0
src/modules/counters/ui/leaderboard.py
Normal file
0
src/modules/counters/ui/stats.py
Normal file
0
src/modules/counters/ui/stats.py
Normal file
8
src/modules/hyperfocus/__init__.py
Normal file
8
src/modules/hyperfocus/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import HyperFocusCog
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(HyperFocusCog(bot))
|
||||
306
src/modules/hyperfocus/cog.py
Normal file
306
src/modules/hyperfocus/cog.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
from twitchAPI.type import AuthScope
|
||||
import random
|
||||
import datetime as dt
|
||||
from datetime import timedelta, datetime
|
||||
|
||||
from meta import CrocBot, LionCog, LionContext, LionBot
|
||||
from meta.sockets import Channel, register_channel
|
||||
from utils.lib import strfdelta, utc_now
|
||||
from . import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class FocusState:
|
||||
userid: int | str
|
||||
name: str
|
||||
focus_ends: datetime
|
||||
hyper: bool = True
|
||||
|
||||
|
||||
class FocusChannel(Channel):
|
||||
name = 'FocusList'
|
||||
|
||||
def __init__(self, cog: 'HyperFocusCog', **kwargs):
|
||||
self.cog = cog
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def on_connection(self, websocket, event):
|
||||
await super().on_connection(websocket, event)
|
||||
await self.reload_focus(websocket=websocket)
|
||||
|
||||
def focus_args(self, state: FocusState):
|
||||
return (
|
||||
state.userid,
|
||||
state.name,
|
||||
state.hyper,
|
||||
state.focus_ends.isoformat(),
|
||||
)
|
||||
|
||||
async def reload_focus(self, websocket=None):
|
||||
"""
|
||||
Clear tasklist and re-send current tasks.
|
||||
"""
|
||||
await self.send_clear(websocket=websocket)
|
||||
for state in self.cog.hyperfocusing.values():
|
||||
await self.send_set(*self.focus_args(state), websocket=websocket)
|
||||
|
||||
async def send_set(self, userid, name, hyper, end_at, websocket=None):
|
||||
await self.send_event({
|
||||
'type': "DO",
|
||||
'method': "setFocus",
|
||||
'args': {
|
||||
'userid': userid,
|
||||
'name': name,
|
||||
'hyper': hyper,
|
||||
'end_at': end_at,
|
||||
}
|
||||
}, websocket=websocket)
|
||||
|
||||
async def send_del(self, userid, websocket=None):
|
||||
await self.send_event({
|
||||
'type': "DO",
|
||||
'method': "delFocus",
|
||||
'args': {
|
||||
'userid': userid,
|
||||
}
|
||||
}, websocket=websocket)
|
||||
|
||||
async def send_clear(self, websocket=None):
|
||||
await self.send_event({
|
||||
'type': "DO",
|
||||
'method': "clearFocus",
|
||||
'args': {
|
||||
}
|
||||
}, websocket=websocket)
|
||||
|
||||
|
||||
|
||||
class HyperFocusCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.crocbot: CrocBot = bot.crocbot
|
||||
|
||||
# userid -> timestamp when they stop
|
||||
self.hyperfocusing: dict[str, FocusState] = {}
|
||||
|
||||
self.channel = FocusChannel(self)
|
||||
register_channel(self.channel.name, self.channel)
|
||||
|
||||
self.loaded = asyncio.Event()
|
||||
|
||||
async def cog_load(self):
|
||||
self._load_twitch_methods(self.crocbot)
|
||||
self.load_hyperfocus()
|
||||
self.loaded.set()
|
||||
|
||||
async def cog_unload(self):
|
||||
self._unload_twitch_methods(self.crocbot)
|
||||
|
||||
def save_hyperfocus(self):
|
||||
with open('hyperfocus.json', 'w', encoding='utf-8') as f:
|
||||
mapped = {
|
||||
userid: {
|
||||
'userid': str(state.userid),
|
||||
'name': state.name,
|
||||
'focus_ends': state.focus_ends.isoformat(),
|
||||
'hyper': state.hyper
|
||||
}
|
||||
for userid, state in self.hyperfocusing.items()
|
||||
}
|
||||
json.dump(mapped, f, ensure_ascii=False, indent=4)
|
||||
|
||||
def load_hyperfocus(self):
|
||||
with open('hyperfocus.json') as f:
|
||||
mapped = json.load(f)
|
||||
self.hyperfocusing.clear()
|
||||
for userid, map in mapped.items():
|
||||
self.hyperfocusing[str(userid)] = FocusState(
|
||||
userid=str(map['userid']),
|
||||
name=map['name'],
|
||||
hyper=map['hyper'],
|
||||
focus_ends=dt.datetime.fromisoformat(map['focus_ends'])
|
||||
)
|
||||
print(f"Loaded hyperfocus: {self.hyperfocusing}")
|
||||
|
||||
def check_hyperfocus(self, userid):
|
||||
"""
|
||||
Returns whether a user is currently in HYPERFOCUS mode!
|
||||
"""
|
||||
return (state := self.hyperfocusing.get(userid, None)) and utc_now() < state.focus_ends
|
||||
|
||||
@commands.Cog.event('event_message')
|
||||
async def on_message(self, message: twitchio.Message):
|
||||
if message.content and message.content.lower() == 'nice':
|
||||
await message.channel.send("That's Nice")
|
||||
|
||||
await self.good_croccy_handler(message)
|
||||
|
||||
tags = message.tags
|
||||
if tags and message.content and self.check_hyperfocus(tags.get('user-id')):
|
||||
if not self.valid_focus_message(message):
|
||||
logger.info(
|
||||
f"Deleting message from hyperfocused user. {message.raw_data=}"
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
msgid = tags['id']
|
||||
# TODO: Better selection for moderator
|
||||
# i.e. if the message is not from the broadcaster and we do have delete perms
|
||||
# then use our own token.
|
||||
broadcasterid = tags['room-id']
|
||||
authcog = self.bot.get_cog('TwitchAuthCog')
|
||||
if not await authcog.check_auth(broadcasterid, scopes=[AuthScope.MODERATOR_MANAGE_CHAT_MESSAGES]):
|
||||
await message.channel.send(f"@{message.author.name} Stay focused! (I tried to delete your message because you are in !hyperfocus. Unfortunately I don't have the permissions to do that. But stay focused anyway!)")
|
||||
else:
|
||||
twitch = await authcog.fetch_client_for(broadcasterid)
|
||||
await twitch.delete_chat_message(
|
||||
broadcasterid,
|
||||
broadcasterid,
|
||||
msgid,
|
||||
)
|
||||
await message.channel.send(
|
||||
f"@{message.author.name} Stay focused! (I deleted your message because you are in !hyperfocus, use !unfocus to come back.)"
|
||||
)
|
||||
|
||||
async def good_croccy_handler(self, message: twitchio.Message):
|
||||
if not message.content:
|
||||
return
|
||||
cleaned = message.content.lower().replace('@croccyhelper', '').strip()
|
||||
if cleaned in ('good croc', 'good croccy', 'good helper'):
|
||||
await message.channel.send("holono1Heart")
|
||||
elif cleaned in ('bad croc', 'bad croccy', 'bad helper'):
|
||||
await message.channel.send("holono1Sad")
|
||||
|
||||
async def chemical_handler(self, message: twitchio.Message):
|
||||
if not message.content:
|
||||
return
|
||||
cleaned = message.content.lower().strip()
|
||||
if cleaned in ('oh',):
|
||||
await message.channel.send('Oxygen Hydrogen!')
|
||||
|
||||
def valid_focus_message(self, message: twitchio.Message) -> bool:
|
||||
"""
|
||||
Determined whether the given message is allowed to be sent in !hyperfocus.
|
||||
That is, if it appears to be emote-only or a command.
|
||||
"""
|
||||
|
||||
content = message.content
|
||||
if not content:
|
||||
return True
|
||||
|
||||
tags = message.tags or {}
|
||||
to_remove = []
|
||||
|
||||
if (replying := tags.get('reply-parent-user-login', '')) and content.startswith('@'):
|
||||
# Trim the mention from the start of the content
|
||||
splits = content.split(maxsplit=1)
|
||||
to_remove.append((0, len(splits[0])))
|
||||
|
||||
if emotesstr := tags.get('emotes', ''):
|
||||
for emotestr in emotesstr.split('/'):
|
||||
emote, locs = emotestr.split(':')
|
||||
for loc in locs.split(','):
|
||||
start, end = loc.split('-')
|
||||
to_remove.append((int(start), int(end) + 1))
|
||||
|
||||
# Sort the pairs to remove by descending starting index
|
||||
# This should allow clean removal with a loop as long as there are no intersections.
|
||||
to_remove.sort(key=lambda pair: pair[0], reverse=True)
|
||||
for start, end in to_remove:
|
||||
content = content[:start] + content[end:]
|
||||
content = content.strip().replace(' ', '').replace('\n', '')
|
||||
allowed = not content or content.startswith('!') or content.startswith('*')
|
||||
allowed = allowed or all(not char.isascii() for char in content)
|
||||
|
||||
if not allowed:
|
||||
logger.info(f"Invalid hyperfocus message. Trimmed content: {content}")
|
||||
|
||||
return allowed
|
||||
|
||||
@commands.command(name='coinflip')
|
||||
async def coinflip(self, ctx):
|
||||
await ctx.reply(random.choice(('heads', 'tails')))
|
||||
|
||||
@commands.command(name='choose')
|
||||
async def choose(self, ctx, *, args: str):
|
||||
if not args:
|
||||
await ctx.reply("Give me something to choose, e.g. !choose Heads | Tails")
|
||||
else:
|
||||
options = args.split('|')
|
||||
options = [option.strip() for option in options]
|
||||
options = [option for option in options if option]
|
||||
choice = random.choice(options)
|
||||
if random.random() < 0.01:
|
||||
choice = "You"
|
||||
await ctx.reply(f"I choose: {choice}")
|
||||
|
||||
@commands.command(name='hyperfocus')
|
||||
async def hyperfocus_cmd(self, ctx, dur: Optional[int] = None):
|
||||
userid = str(ctx.author.id)
|
||||
now = utc_now()
|
||||
end_time = None
|
||||
|
||||
if dur is None:
|
||||
# Automatically select time
|
||||
next_hour = now.replace(minute=0, second=0, microsecond=0) + dt.timedelta(hours=1)
|
||||
next_block = next_hour - dt.timedelta(minutes=10)
|
||||
if now > next_block:
|
||||
# Currently in the break
|
||||
next_block = next_block + dt.timedelta(hours=1)
|
||||
end_time = next_block
|
||||
dur = int((end_time - now).total_seconds() // 60)
|
||||
elif dur > 720:
|
||||
await ctx.reply("You can hyperfocus for at most 12 hours at a time!")
|
||||
else:
|
||||
end_time = utc_now() + dt.timedelta(minutes=dur)
|
||||
|
||||
if end_time is not None:
|
||||
state = self.hyperfocusing[userid] = FocusState(
|
||||
userid=userid,
|
||||
name=ctx.author.display_name,
|
||||
focus_ends=end_time,
|
||||
)
|
||||
self.save_hyperfocus()
|
||||
await self.channel.send_set(*self.channel.focus_args(state))
|
||||
await ctx.reply(
|
||||
f"{ctx.author.name} has gone into HYPERFOCUS mode! "
|
||||
f"They will be in emote and command only mode for the next {dur} minutes! "
|
||||
"Use !unfocus if you really need to chat before then, best of luck! 🍀"
|
||||
)
|
||||
|
||||
@commands.command(name='unfocus')
|
||||
async def unfocus_cmd(self, ctx):
|
||||
self.hyperfocusing.pop(ctx.author.id, None)
|
||||
self.save_hyperfocus()
|
||||
await self.channel.send_del(ctx.author.id)
|
||||
await ctx.reply("Welcome back from focus, hope it went well! Have a comfy break and remember to have a sippie and a stretch~")
|
||||
|
||||
@commands.command(name='hyperfocused')
|
||||
async def focused_cmd(self, ctx, user: Optional[twitchio.User] = None):
|
||||
user = user if user is not None else ctx.author
|
||||
userid = str(user.id)
|
||||
if self.check_hyperfocus(userid):
|
||||
state = self.hyperfocusing.get(userid)
|
||||
end_time = state.focus_ends
|
||||
durstr = strfdelta(end_time - utc_now())
|
||||
await ctx.reply(
|
||||
f"{user.name} is in HYPERFOCUS for another {durstr}! "
|
||||
"They can only write emojis and commands in this time~ "
|
||||
"(use !unfocus to come back if you need to!) "
|
||||
"Good luck!"
|
||||
)
|
||||
elif userid != str(ctx.author.id):
|
||||
await ctx.reply(
|
||||
f"{user.name} is not hyperfocused!"
|
||||
)
|
||||
else:
|
||||
await ctx.reply(
|
||||
"You are not hyperfocused! "
|
||||
"Enter HYPERFOCUS mode for e.g. 10 minutes by writing !hyperfocus 10"
|
||||
)
|
||||
@@ -4,17 +4,21 @@ import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from attr import dataclass
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
|
||||
from meta import CrocBot, LionCog
|
||||
from meta.LionBot import LionBot
|
||||
from meta import CrocBot, LionCog, LionContext, LionBot
|
||||
from meta.sockets import Channel, register_channel
|
||||
from utils.lib import strfdelta, utc_now
|
||||
from . import logger
|
||||
from .data import NowListData
|
||||
|
||||
from modules.profiles.profile import UserProfile
|
||||
|
||||
|
||||
class NowDoingChannel(Channel):
|
||||
name = 'NowList'
|
||||
@@ -25,19 +29,7 @@ class NowDoingChannel(Channel):
|
||||
|
||||
async def on_connection(self, websocket, event):
|
||||
await super().on_connection(websocket, event)
|
||||
for task in self.cog.tasks.values():
|
||||
await self.send_set(*self.task_args(task), websocket=websocket)
|
||||
|
||||
async def send_test_set(self):
|
||||
tasks = [
|
||||
(0, 'Tester0', "Testing Tasklist", True),
|
||||
(1, 'Tester1', "Getting Confused", False),
|
||||
(2, "Tester2", "Generating Bugs", True),
|
||||
(3, "Tester3", "Fixing Bugs", False),
|
||||
(4, "Tester4", "Pushing the red button", False),
|
||||
]
|
||||
for task in tasks:
|
||||
await self.send_set(*task)
|
||||
await self.reload_tasklist(websocket=websocket)
|
||||
|
||||
def task_args(self, task: NowListData.Task):
|
||||
return (
|
||||
@@ -48,6 +40,14 @@ class NowDoingChannel(Channel):
|
||||
task.done_at.isoformat() if task.done_at else None,
|
||||
)
|
||||
|
||||
async def reload_tasklist(self, websocket=None):
|
||||
"""
|
||||
Clear tasklist and re-send current tasks.
|
||||
"""
|
||||
await self.send_clear(websocket=websocket)
|
||||
for task in self.cog.tasks.values():
|
||||
await self.send_set(*self.task_args(task), websocket=websocket)
|
||||
|
||||
async def send_set(self, userid, name, task, start_at, end_at, websocket=None):
|
||||
await self.send_event({
|
||||
'type': "DO",
|
||||
@@ -61,28 +61,28 @@ class NowDoingChannel(Channel):
|
||||
}
|
||||
}, websocket=websocket)
|
||||
|
||||
async def send_del(self, userid):
|
||||
async def send_del(self, userid, websocket=None):
|
||||
await self.send_event({
|
||||
'type': "DO",
|
||||
'method': "delTask",
|
||||
'args': {
|
||||
'userid': userid,
|
||||
}
|
||||
})
|
||||
}, websocket=websocket)
|
||||
|
||||
async def send_clear(self):
|
||||
async def send_clear(self, websocket=None):
|
||||
await self.send_event({
|
||||
'type': "DO",
|
||||
'method': "clearTasks",
|
||||
'args': {
|
||||
}
|
||||
})
|
||||
}, websocket=websocket)
|
||||
|
||||
|
||||
class NowDoingCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.crocbot = bot.crocbot
|
||||
self.crocbot: CrocBot = bot.crocbot
|
||||
self.data = bot.db.load_registry(NowListData())
|
||||
self.channel = NowDoingChannel(self)
|
||||
register_channel(self.channel.name, self.channel)
|
||||
@@ -94,17 +94,82 @@ class NowDoingCog(LionCog):
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
await self.load_tasks()
|
||||
|
||||
self.bot.get_cog('ProfileCog').add_profile_migrator(self.migrate_profiles, name='task-migrator')
|
||||
|
||||
self._load_twitch_methods(self.crocbot)
|
||||
self.loaded.set()
|
||||
|
||||
async def cog_unload(self):
|
||||
self.loaded.clear()
|
||||
self.tasks.clear()
|
||||
if profiles := self.bot.get_cog('ProfileCog'):
|
||||
profiles.del_profile_migrator('task-migrator')
|
||||
self._unload_twitch_methods(self.crocbot)
|
||||
|
||||
async def migrate_profiles(self, source_profile: UserProfile, target_profile: UserProfile):
|
||||
"""
|
||||
Move current source task to target profile if there's room for it, otherwise annihilate
|
||||
"""
|
||||
await self.load_tasks()
|
||||
source_task = self.tasks.pop(source_profile.profileid, None)
|
||||
|
||||
results = ["(Tasklist)"]
|
||||
|
||||
if source_task:
|
||||
target_task = self.tasks.get(target_profile.profileid, None)
|
||||
if target_task and (target_task.is_done or target_task.started_at < source_task.started_at):
|
||||
# If target is done, remove it so we can overwrite
|
||||
results.append("Removed older task from target profile.")
|
||||
await target_task.delete()
|
||||
target_task = None
|
||||
|
||||
if not target_task:
|
||||
# Update source task with new profile id
|
||||
await source_task.update(userid=target_profile.profileid)
|
||||
target_task = source_task
|
||||
await self.channel.send_set(*self.channel.task_args(target_task))
|
||||
results.append("Migrated 1 currently running task from source profile.")
|
||||
else:
|
||||
# If there is a target task we can't overwrite, just delete the source task
|
||||
await source_task.delete()
|
||||
results.append("Ignoring and removing older task from source profile.")
|
||||
|
||||
self.tasks.pop(source_profile.profileid, None)
|
||||
await self.channel.send_del(source_profile.profileid)
|
||||
else:
|
||||
results.append("No running task in source profile, nothing to migrate!")
|
||||
await self.load_tasks()
|
||||
|
||||
return ' '.join(results)
|
||||
|
||||
async def user_profile_migration(self):
|
||||
"""
|
||||
Manual single-use migration method from the old userid format to the new profileid format.
|
||||
"""
|
||||
await self.load_tasks()
|
||||
for userid, task in self.tasks.items():
|
||||
userid = int(userid)
|
||||
if userid > 1000:
|
||||
# Assume it is a twitch userid
|
||||
profile = await UserProfile.fetch_from_twitchid(self.bot, userid)
|
||||
|
||||
if not profile:
|
||||
# Create a new profile with this twitch user
|
||||
users = await self.crocbot.fetch_users(ids=[userid])
|
||||
if not users:
|
||||
continue
|
||||
user = users[0]
|
||||
profile = await UserProfile.create_from_twitch(self.bot, user)
|
||||
|
||||
if not await self.data.Task.fetch(profile.profileid):
|
||||
await task.update(userid=profile.profileid)
|
||||
else:
|
||||
await task.delete()
|
||||
await self.load_tasks()
|
||||
await self.channel.reload_tasklist()
|
||||
|
||||
async def cog_check(self, ctx):
|
||||
if not self.loaded.is_set():
|
||||
await ctx.reply("Tasklists are still loading! Please wait a moment~")
|
||||
@@ -123,25 +188,27 @@ class NowDoingCog(LionCog):
|
||||
# await self.channel.send_test_set()
|
||||
# await ctx.send(f"Hello {ctx.author.name}! This command does something, we aren't sure what yet.")
|
||||
# await ctx.send(str(list(self.tasks.items())[0]))
|
||||
await self.user_profile_migration()
|
||||
await ctx.send(str(ctx.author.id))
|
||||
await ctx.reply("Userid -> profile migration done.")
|
||||
else:
|
||||
await ctx.send(f"Hello {ctx.author.name}! I don't think you have permission to test that.")
|
||||
|
||||
@commands.command(aliases=['task', 'check'])
|
||||
async def now(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||
userid = int(ctx.author.id)
|
||||
async def now(self, ctx: commands.Context | LionContext, profile: UserProfile, args: Optional[str] = None, edit=False):
|
||||
args = args.strip() if args else None
|
||||
userid = profile.profileid
|
||||
if args:
|
||||
existing = self.tasks.get(userid, None)
|
||||
await self.data.Task.table.delete_where(userid=userid)
|
||||
task = await self.data.Task.create(
|
||||
userid=userid,
|
||||
name=ctx.author.display_name,
|
||||
name=await profile.get_name(),
|
||||
task=args,
|
||||
started_at=utc_now(),
|
||||
started_at=existing.started_at if (existing and edit) else utc_now(),
|
||||
)
|
||||
self.tasks[task.userid] = task
|
||||
await self.channel.send_set(*self.channel.task_args(task))
|
||||
await ctx.send(f"Updated your current task, good luck!")
|
||||
await ctx.send("Updated your current task, good luck!")
|
||||
elif task := self.tasks.get(userid, None):
|
||||
if task.is_done:
|
||||
done_ago = strfdelta(utc_now() - task.done_at)
|
||||
@@ -159,9 +226,38 @@ class NowDoingCog(LionCog):
|
||||
"Show what you are currently working on with, e.g. !now Reading notes"
|
||||
)
|
||||
|
||||
@commands.command(name='next')
|
||||
async def nownext(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||
userid = int(ctx.author.id)
|
||||
@commands.command(
|
||||
name='now',
|
||||
aliases=['task', 'check']
|
||||
)
|
||||
async def twi_now(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||
await self.now(ctx, profile, args)
|
||||
|
||||
@cmds.hybrid_command(
|
||||
name='now',
|
||||
aliases=['task', 'check']
|
||||
)
|
||||
async def disc_now(self, ctx: LionContext, *, args: Optional[str] = None):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||
await self.now(ctx, profile, args)
|
||||
|
||||
@commands.command(
|
||||
name='edit',
|
||||
)
|
||||
async def twi_edit(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||
await self.now(ctx, profile, args, edit=True)
|
||||
|
||||
@cmds.hybrid_command(
|
||||
name='edit',
|
||||
)
|
||||
async def disc_edit(self, ctx: LionContext, *, args: Optional[str] = None):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||
await self.now(ctx, profile, args, edit=True)
|
||||
|
||||
async def nownext(self, ctx: commands.Context | LionContext, profile: UserProfile, args: Optional[str]):
|
||||
userid = profile.profileid
|
||||
task = self.tasks.get(userid, None)
|
||||
if args:
|
||||
if task:
|
||||
@@ -176,13 +272,13 @@ class NowDoingCog(LionCog):
|
||||
await self.data.Task.table.delete_where(userid=userid)
|
||||
task = await self.data.Task.create(
|
||||
userid=userid,
|
||||
name=ctx.author.display_name,
|
||||
name=await profile.get_name(),
|
||||
task=args,
|
||||
started_at=utc_now(),
|
||||
)
|
||||
self.tasks[task.userid] = task
|
||||
await self.channel.send_set(*self.channel.task_args(task))
|
||||
await ctx.send(f"Next task set, good luck!" + ' ' + prefix)
|
||||
await ctx.send("Next task set, good luck!" + ' ' + prefix)
|
||||
elif task:
|
||||
if task.is_done:
|
||||
done_ago = strfdelta(utc_now() - task.done_at)
|
||||
@@ -200,9 +296,22 @@ class NowDoingCog(LionCog):
|
||||
"Show what you are currently working on with, e.g. !now Reading notes"
|
||||
)
|
||||
|
||||
@commands.command()
|
||||
async def done(self, ctx: commands.Context):
|
||||
userid = int(ctx.author.id)
|
||||
@commands.command(
|
||||
name='next',
|
||||
)
|
||||
async def twi_next(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||
await self.nownext(ctx, profile, args)
|
||||
|
||||
@cmds.hybrid_command(
|
||||
name='next',
|
||||
)
|
||||
async def disc_next(self, ctx: LionContext, *, args: Optional[str] = None):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||
await self.nownext(ctx, profile, args)
|
||||
|
||||
async def done(self, ctx: commands.Context | LionContext, profile: UserProfile):
|
||||
userid = profile.profileid
|
||||
if task := self.tasks.get(userid, None):
|
||||
if task.is_done:
|
||||
await ctx.send(
|
||||
@@ -222,9 +331,36 @@ class NowDoingCog(LionCog):
|
||||
"Show what you are currently working on with, e.g. !now Reading notes"
|
||||
)
|
||||
|
||||
@commands.command()
|
||||
async def clear(self, ctx: commands.Context):
|
||||
userid = int(ctx.author.id)
|
||||
@commands.command(
|
||||
name='done',
|
||||
)
|
||||
async def twi_done(self, ctx: commands.Context):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||
await self.done(ctx, profile)
|
||||
|
||||
@cmds.hybrid_command(
|
||||
name='done',
|
||||
)
|
||||
async def disc_done(self, ctx: LionContext):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||
await self.done(ctx, profile)
|
||||
|
||||
@commands.command(
|
||||
name='clear',
|
||||
)
|
||||
async def twi_clear(self, ctx: commands.Context):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||
await self.clear(ctx, profile)
|
||||
|
||||
@cmds.hybrid_command(
|
||||
name='clear',
|
||||
)
|
||||
async def disc_clear(self, ctx: LionContext):
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_discord(ctx.author)
|
||||
await self.clear(ctx, profile)
|
||||
|
||||
async def clear(self, ctx: commands.Context | LionContext, profile):
|
||||
userid = profile.profileid
|
||||
if task := self.tasks.pop(userid, None):
|
||||
await task.delete()
|
||||
await self.channel.send_del(userid)
|
||||
|
||||
@@ -47,16 +47,32 @@ class TimerChannel(Channel):
|
||||
super().__init__(**kwargs)
|
||||
self.cog = cog
|
||||
|
||||
self.channelid = 1261999440160624734
|
||||
self.goal = 12
|
||||
|
||||
async def on_connection(self, websocket, event):
|
||||
await super().on_connection(websocket, event)
|
||||
timer = self.cog.get_channel_timer(1261999440160624734)
|
||||
if timer is not None:
|
||||
await self.send_set(
|
||||
timer.data.last_started,
|
||||
timer.data.focus_length,
|
||||
timer.data.break_length,
|
||||
websocket=websocket,
|
||||
)
|
||||
await self.send_set(
|
||||
**await self.get_args_for(self.channelid),
|
||||
goal=self.goal,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
async def send_updates(self):
|
||||
await self.send_set(
|
||||
**await self.get_args_for(self.channelid),
|
||||
goal=self.goal,
|
||||
)
|
||||
|
||||
async def get_args_for(self, channelid):
|
||||
timer = self.cog.get_channel_timer(channelid)
|
||||
if timer is None:
|
||||
raise ValueError(f"Timer {channelid} doesn't exist.")
|
||||
return {
|
||||
'start_at': timer.data.last_started,
|
||||
'focus_length': timer.data.focus_length,
|
||||
'break_length': timer.data.break_length,
|
||||
}
|
||||
|
||||
async def send_set(self, start_at, focus_length, break_length, goal=12, websocket=None):
|
||||
await self.send_event({
|
||||
@@ -304,8 +320,6 @@ class TimerCog(LionCog):
|
||||
return
|
||||
if member.bot:
|
||||
return
|
||||
if 1148167212901859328 not in [role.id for role in member.roles]:
|
||||
return
|
||||
|
||||
# If a member is leaving or joining a running timer, trigger a status update
|
||||
if before.channel != after.channel:
|
||||
@@ -315,6 +329,7 @@ class TimerCog(LionCog):
|
||||
tasks = []
|
||||
if leaving is not None:
|
||||
tasks.append(asyncio.create_task(leaving.update_status_card()))
|
||||
leaving.last_seen.pop(member.id, None)
|
||||
if joining is not None:
|
||||
joining.last_seen[member.id] = utc_now()
|
||||
if not joining.running and joining.auto_restart:
|
||||
@@ -1059,8 +1074,18 @@ class TimerCog(LionCog):
|
||||
@low_management_ward
|
||||
async def streamtimer_update_cmd(self, ctx: LionContext,
|
||||
new_start: Optional[str] = None,
|
||||
new_goal: int = 12):
|
||||
timer = self.get_channel_timer(1261999440160624734)
|
||||
new_goal: Optional[int] = None,
|
||||
new_channel: Optional[discord.VoiceChannel] = None,
|
||||
):
|
||||
if new_channel is not None:
|
||||
channelid = self.channel.channelid = new_channel.id
|
||||
else:
|
||||
channelid = self.channel.channelid
|
||||
|
||||
if new_goal is not None:
|
||||
self.channel.goal = new_goal
|
||||
|
||||
timer = self.get_channel_timer(channelid)
|
||||
if timer is None:
|
||||
return
|
||||
if new_start:
|
||||
@@ -1068,10 +1093,5 @@ class TimerCog(LionCog):
|
||||
start_at = await self.bot.get_cog('Reminders').parse_time_static(new_start, timezone)
|
||||
await timer.data.update(last_started=start_at)
|
||||
|
||||
await self.channel.send_set(
|
||||
timer.data.last_started,
|
||||
timer.data.focus_length,
|
||||
timer.data.break_length,
|
||||
goal=new_goal,
|
||||
)
|
||||
await self.channel.send_updates()
|
||||
await ctx.reply("Stream Timer Updated")
|
||||
|
||||
@@ -8,10 +8,12 @@ from gui.cards import FocusTimerCard, BreakTimerCard
|
||||
if TYPE_CHECKING:
|
||||
from .timer import Timer, Stage
|
||||
from tracking.voice.cog import VoiceTrackerCog
|
||||
from modules.nowdoing.cog import NowDoingCog
|
||||
|
||||
|
||||
async def get_timer_card(bot: LionBot, timer: 'Timer', stage: 'Stage'):
|
||||
voicecog: 'VoiceTrackerCog' = bot.get_cog('VoiceTrackerCog')
|
||||
nowcog: 'NowDoingCog' = bot.get_cog('NowDoingCog')
|
||||
|
||||
name = timer.base_name
|
||||
if stage is not None:
|
||||
@@ -23,16 +25,22 @@ async def get_timer_card(bot: LionBot, timer: 'Timer', stage: 'Stage'):
|
||||
card_users = []
|
||||
guildid = timer.data.guildid
|
||||
for member in timer.members:
|
||||
if voicecog is not None:
|
||||
session = voicecog.get_session(guildid, member.id)
|
||||
tag = session.tag
|
||||
if session.start_time:
|
||||
session_duration = (utc_now() - session.start_time).total_seconds()
|
||||
else:
|
||||
session_duration = 0
|
||||
profile = await bot.get_cog('ProfileCog').fetch_profile_discord(member)
|
||||
task = nowcog.tasks.get(profile.profileid, None)
|
||||
tag = ''
|
||||
session_duration = 0
|
||||
|
||||
if task:
|
||||
tag = task.task
|
||||
session_duration = ((task.done_at or utc_now()) - task.started_at).total_seconds()
|
||||
else:
|
||||
session_duration = 0
|
||||
tag = None
|
||||
session = voicecog.get_session(guildid, member.id)
|
||||
if session:
|
||||
tag = session.tag
|
||||
if session.start_time:
|
||||
session_duration = (utc_now() - session.start_time).total_seconds()
|
||||
else:
|
||||
session_duration = 0
|
||||
|
||||
card_user = (
|
||||
(member.id, (member.avatar or member.default_avatar).key),
|
||||
|
||||
8
src/modules/profiles/__init__.py
Normal file
8
src/modules/profiles/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import ProfileCog
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(ProfileCog(bot))
|
||||
404
src/modules/profiles/cog.py
Normal file
404
src/modules/profiles/cog.py
Normal file
@@ -0,0 +1,404 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Optional, overload
|
||||
from datetime import timedelta
|
||||
|
||||
import discord
|
||||
from discord import app_commands as appcmds
|
||||
from discord.ext import commands as cmds
|
||||
from twitchAPI.type import AuthScope
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
from twitchio import User
|
||||
from twitchAPI.object.api import TwitchUser
|
||||
|
||||
|
||||
from data.queries import ORDER
|
||||
from meta import LionCog, LionBot, CrocBot, LionContext
|
||||
from meta.logger import log_wrap
|
||||
from utils.lib import utc_now
|
||||
from . import logger
|
||||
from .data import ProfileData
|
||||
from .profile import UserProfile
|
||||
from .community import Community
|
||||
|
||||
|
||||
class ProfileCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
|
||||
assert bot.crocbot is not None
|
||||
self.crocbot: CrocBot = bot.crocbot
|
||||
self.data = bot.db.load_registry(ProfileData())
|
||||
|
||||
self._profile_migrators = {}
|
||||
self._comm_migrators = {}
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
async def cog_check(self, ctx):
|
||||
return True
|
||||
|
||||
# Profile API
|
||||
def add_profile_migrator(self, migrator, name=None):
|
||||
name = name or migrator.__name__
|
||||
self._profile_migrators[name or migrator.__name__] = migrator
|
||||
|
||||
logger.info(
|
||||
f"Added user profile migrator {name}: {migrator}"
|
||||
)
|
||||
return migrator
|
||||
|
||||
def del_profile_migrator(self, name: str):
|
||||
migrator = self._profile_migrators.pop(name, None)
|
||||
|
||||
logger.info(
|
||||
f"Removed user profile migrator {name}: {migrator}"
|
||||
)
|
||||
|
||||
@log_wrap(action="profile migration")
|
||||
async def migrate_profile(self, source_profile, target_profile) -> list[str]:
|
||||
logger.info(
|
||||
f"Beginning user profile migration from {source_profile!r} to {target_profile!r}"
|
||||
)
|
||||
results = []
|
||||
# Wrap this in a transaction so if something goes wrong with migration,
|
||||
# we roll back safely (although this may mess up caches)
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
for name, migrator in self._profile_migrators.items():
|
||||
try:
|
||||
result = await migrator(source_profile, target_profile)
|
||||
if result:
|
||||
results.append(result)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unexpected exception running user profile migrator {name} "
|
||||
f"migrating {source_profile!r} to {target_profile!r}."
|
||||
)
|
||||
raise
|
||||
|
||||
# Move all Discord and Twitch profile references over to the new profile
|
||||
discord_rows = await self.data.DiscordProfileRow.table.update_where(
|
||||
profileid=source_profile.profileid
|
||||
).set(profileid=target_profile.profileid)
|
||||
results.append(f"Migrated {len(discord_rows)} attached discord account(s).")
|
||||
|
||||
twitch_rows = await self.data.TwitchProfileRow.table.update_where(
|
||||
profileid=source_profile.profileid
|
||||
).set(profileid=target_profile.profileid)
|
||||
results.append(f"Migrated {len(twitch_rows)} attached twitch account(s).")
|
||||
|
||||
# And then mark the old profile as migrated
|
||||
await source_profile.profile_row.update(migrated=target_profile.profileid)
|
||||
results.append("Marking old profile as migrated.. finished!")
|
||||
return results
|
||||
|
||||
async def fetch_profile_by_id(self, profile_id: int) -> UserProfile:
|
||||
"""
|
||||
Fetch a UserProfile by the given id.
|
||||
"""
|
||||
return await UserProfile.fetch(self.bot, profile_id=profile_id)
|
||||
|
||||
async def fetch_profile_discord(self, user: discord.Member | discord.User) -> UserProfile:
|
||||
"""
|
||||
Fetch or create a UserProfile from the provided discord account.
|
||||
"""
|
||||
profile = await UserProfile.fetch_from_discordid(self.bot, user.id)
|
||||
if profile is None:
|
||||
profile = await UserProfile.create_from_discord(self.bot, user)
|
||||
return profile
|
||||
|
||||
async def fetch_profile_twitch(self, user: twitchio.User) -> UserProfile:
|
||||
"""
|
||||
Fetch or create a UserProfile from the provided twitch account.
|
||||
"""
|
||||
profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
|
||||
if profile is None:
|
||||
profile = await UserProfile.create_from_twitch(self.bot, user)
|
||||
return profile
|
||||
|
||||
# Community API
|
||||
def add_community_migrator(self, migrator, name=None):
|
||||
name = name or migrator.__name__
|
||||
self._comm_migrators[name or migrator.__name__] = migrator
|
||||
|
||||
logger.info(
|
||||
f"Added community migrator {name}: {migrator}"
|
||||
)
|
||||
return migrator
|
||||
|
||||
def del_community_migrator(self, name: str):
|
||||
migrator = self._comm_migrators.pop(name, None)
|
||||
|
||||
logger.info(
|
||||
f"Removed community migrator {name}: {migrator}"
|
||||
)
|
||||
|
||||
@log_wrap(action="community migration")
|
||||
async def migrate_community(self, source_comm, target_comm) -> list[str]:
|
||||
logger.info(
|
||||
f"Beginning community migration from {source_comm!r} to {target_comm!r}"
|
||||
)
|
||||
results = []
|
||||
# Wrap this in a transaction so if something goes wrong with migration,
|
||||
# we roll back safely (although this may mess up caches)
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
for name, migrator in self._comm_migrators.items():
|
||||
try:
|
||||
result = await migrator(source_comm, target_comm)
|
||||
if result:
|
||||
results.append(result)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unexpected exception running community migrator {name} "
|
||||
f"migrating {source_comm!r} to {target_comm!r}."
|
||||
)
|
||||
raise
|
||||
|
||||
# Move all Discord and Twitch community preferences over to the new profile
|
||||
discord_rows = await self.data.DiscordCommunityRow.table.update_where(
|
||||
profileid=source_comm.communityid
|
||||
).set(communityid=target_comm.communityid)
|
||||
results.append(f"Migrated {len(discord_rows)} attached discord guilds.")
|
||||
|
||||
twitch_rows = await self.data.TwitchCommunityRow.table.update_where(
|
||||
communityid=source_comm.communityid
|
||||
).set(communityid=target_comm.communityid)
|
||||
results.append(f"Migrated {len(twitch_rows)} attached twitch channel(s).")
|
||||
|
||||
# And then mark the old community as migrated
|
||||
await source_comm.update(migrated=target_comm.communityid)
|
||||
results.append("Marking old community as migrated.. finished!")
|
||||
return results
|
||||
|
||||
async def fetch_community_by_id(self, community_id: int) -> Community:
|
||||
"""
|
||||
Fetch a Community by the given id.
|
||||
"""
|
||||
return await Community.fetch(self.bot, community_id=community_id)
|
||||
|
||||
async def fetch_community_discord(self, guild: discord.Guild) -> Community:
|
||||
"""
|
||||
Fetch or create a Community from the provided discord guild.
|
||||
"""
|
||||
comm = await Community.fetch_from_discordid(self.bot, guild.id)
|
||||
if comm is None:
|
||||
comm = await Community.create_from_discord(self.bot, guild)
|
||||
return comm
|
||||
|
||||
async def fetch_community_twitch(self, user: twitchio.User) -> Community:
|
||||
"""
|
||||
Fetch or create a Community from the provided twitch account.
|
||||
"""
|
||||
community = await Community.fetch_from_twitchid(self.bot, user.id)
|
||||
if community is None:
|
||||
community = await Community.create_from_twitch(self.bot, user)
|
||||
return community
|
||||
|
||||
# ----- Profile Commands -----
|
||||
@cmds.hybrid_group(
|
||||
name='profiles',
|
||||
description="Base comand group for user profiles."
|
||||
)
|
||||
async def profiles_grp(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@profiles_grp.group(
|
||||
name='link',
|
||||
description="Base command group for linking profiles"
|
||||
)
|
||||
async def profiles_link_grp(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@profiles_link_grp.command(
|
||||
name='twitch',
|
||||
description="Link a twitch account to your current profile."
|
||||
)
|
||||
async def profiles_link_twitch_cmd(self, ctx: LionContext):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
|
||||
await ctx.interaction.response.defer(ephemeral=True)
|
||||
|
||||
# Ask the user to go through auth to get their userid
|
||||
auth_cog = self.bot.get_cog('TwitchAuthCog')
|
||||
flow = await auth_cog.start_auth()
|
||||
message = await ctx.reply(
|
||||
f"Please [click here]({flow.auth.return_auth_url()}) to link your profile "
|
||||
"to Twitch."
|
||||
)
|
||||
authrow = await flow.run()
|
||||
await message.edit(
|
||||
content="Authentication Complete! Beginning profile merge..."
|
||||
)
|
||||
|
||||
results = await self.crocbot.fetch_users(ids=[authrow.userid])
|
||||
if not results:
|
||||
logger.error(
|
||||
f"User {authrow} obtained from Twitch authentication does not exist."
|
||||
)
|
||||
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||
return
|
||||
|
||||
user = results[0]
|
||||
|
||||
# Retrieve author's profile if it exists
|
||||
author_profile = await UserProfile.fetch_from_discordid(self.bot, ctx.author.id)
|
||||
|
||||
# Check if the twitch-side user has a profile
|
||||
source_profile = await UserProfile.fetch_from_twitchid(self.bot, user.id)
|
||||
|
||||
if author_profile and source_profile is None:
|
||||
# All we need to do is attach the twitch row
|
||||
await author_profile.attach_twitch(user)
|
||||
await message.edit(
|
||||
content=f"Successfully added Twitch account **{user.name}**! There was no profile data to merge."
|
||||
)
|
||||
elif source_profile and author_profile is None:
|
||||
# Attach the discord row to the profile
|
||||
await source_profile.attach_discord(ctx.author)
|
||||
await message.edit(
|
||||
content=f"Successfully connected to Twitch profile **{user.name}**! There was no profile data to merge."
|
||||
)
|
||||
elif source_profile is None and author_profile is None:
|
||||
profile = await UserProfile.create_from_discord(self.bot, ctx.author)
|
||||
await profile.attach_twitch(user)
|
||||
|
||||
await message.edit(
|
||||
content=f"Opened a new user profile for you and linked Twitch account **{user.name}**."
|
||||
)
|
||||
elif author_profile.profileid == source_profile.profileid:
|
||||
await message.edit(
|
||||
content=f"The Twitch account **{user.name}** is already linked to your profile!"
|
||||
)
|
||||
else:
|
||||
# Migrate the existing profile data to the new profiles
|
||||
try:
|
||||
results = await self.migrate_profile(source_profile, author_profile)
|
||||
except Exception:
|
||||
await ctx.error_reply(
|
||||
"An issue was encountered while merging your account profiles!\n"
|
||||
"Migration rolled back, no data has been lost.\n"
|
||||
"The developer has been notified. Please try again later!"
|
||||
)
|
||||
raise
|
||||
|
||||
content = '\n'.join((
|
||||
"## Connecting Twitch account and merging profiles...",
|
||||
*results,
|
||||
"**Successfully linked account and merged profile data!**"
|
||||
))
|
||||
await message.edit(content=content)
|
||||
|
||||
# ----- Community Commands -----
|
||||
@cmds.hybrid_group(
|
||||
name='community',
|
||||
description="Base comand group for community profiles."
|
||||
)
|
||||
async def community_grp(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@community_grp.group(
|
||||
name='link',
|
||||
description="Base command group for linking communities"
|
||||
)
|
||||
async def community_link_grp(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@community_link_grp.command(
|
||||
name='twitch',
|
||||
description="Link a twitch account to this community."
|
||||
)
|
||||
@appcmds.guild_only()
|
||||
@appcmds.default_permissions(manage_guild=True)
|
||||
async def comm_link_twitch_cmd(self, ctx: LionContext):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
assert ctx.guild is not None
|
||||
|
||||
await ctx.interaction.response.defer(ephemeral=True)
|
||||
|
||||
if not ctx.author.guild_permissions.manage_guild:
|
||||
await ctx.error_reply("You need the `MANAGE_GUILD` permission to link this guild to a community.")
|
||||
return
|
||||
|
||||
# Ask the user to go through auth to get their userid
|
||||
auth_cog = self.bot.get_cog('TwitchAuthCog')
|
||||
flow = await auth_cog.start_auth(
|
||||
scopes=[
|
||||
AuthScope.CHAT_EDIT,
|
||||
AuthScope.CHAT_READ,
|
||||
AuthScope.MODERATION_READ,
|
||||
AuthScope.CHANNEL_BOT,
|
||||
]
|
||||
)
|
||||
message = await ctx.reply(
|
||||
f"Please [click here]({flow.auth.return_auth_url()}) to link your Twitch channel to this server."
|
||||
)
|
||||
authrow = await flow.run()
|
||||
await message.edit(
|
||||
content="Authentication Complete! Beginning community profile merge..."
|
||||
)
|
||||
|
||||
results = await self.crocbot.fetch_users(ids=[authrow.userid])
|
||||
if not results:
|
||||
logger.error(
|
||||
f"User {authrow} obtained from Twitch authentication does not exist."
|
||||
)
|
||||
await ctx.error_reply("Sorry, something went wrong. Please try again later!")
|
||||
return
|
||||
|
||||
user = results[0]
|
||||
|
||||
# Retrieve author's profile if it exists
|
||||
guild_comm = await Community.fetch_from_discordid(self.bot, ctx.guild.id)
|
||||
|
||||
# Check if the twitch-side user has a profile
|
||||
twitch_comm = await Community.fetch_from_twitchid(self.bot, user.id)
|
||||
|
||||
if guild_comm and twitch_comm is None:
|
||||
# All we need to do is attach the twitch row
|
||||
await guild_comm.attach_twitch(user)
|
||||
await message.edit(
|
||||
content=f"Successfully linked Twitch channel **{user.name}**! There was no community data to merge."
|
||||
)
|
||||
elif twitch_comm and guild_comm is None:
|
||||
# Attach the discord row to the profile
|
||||
await twitch_comm.attach_discord(ctx.guild)
|
||||
await message.edit(
|
||||
content=f"Successfully connected to Twitch channel **{user.name}**!"
|
||||
)
|
||||
elif twitch_comm is None and guild_comm is None:
|
||||
profile = await Community.create_from_discord(self.bot, ctx.guild)
|
||||
await profile.attach_twitch(user)
|
||||
|
||||
await message.edit(
|
||||
content=f"Created a new community for this server and linked Twitch account **{user.name}**."
|
||||
)
|
||||
elif guild_comm.communityid == twitch_comm.communityid:
|
||||
await message.edit(
|
||||
content=f"This server is already linked to the Twitch channel **{user.name}**!"
|
||||
)
|
||||
else:
|
||||
# Migrate the existing profile data to the new profiles
|
||||
try:
|
||||
results = await self.migrate_community(twitch_comm, guild_comm)
|
||||
except Exception:
|
||||
await ctx.error_reply(
|
||||
"An issue was encountered while merging your community profiles!\n"
|
||||
"Migration rolled back, no data has been lost.\n"
|
||||
"The developer has been notified. Please try again later!"
|
||||
)
|
||||
raise
|
||||
|
||||
content = '\n'.join((
|
||||
"## Connecting Twitch account and merging community profiles...",
|
||||
*results,
|
||||
"**Successfully linked account and merged community data!**"
|
||||
))
|
||||
await message.edit(content=content)
|
||||
123
src/modules/profiles/community.py
Normal file
123
src/modules/profiles/community.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import Optional, Self
|
||||
|
||||
import discord
|
||||
import twitchio
|
||||
|
||||
from meta import LionBot
|
||||
from utils.lib import utc_now
|
||||
|
||||
from . import logger
|
||||
from .data import ProfileData
|
||||
|
||||
|
||||
|
||||
class Community:
|
||||
def __init__(self, bot: LionBot, community_row):
|
||||
self.bot = bot
|
||||
self.row: ProfileData.CommunityRow = community_row
|
||||
|
||||
@property
|
||||
def cog(self):
|
||||
return self.bot.get_cog('ProfileCog')
|
||||
|
||||
@property
|
||||
def data(self) -> ProfileData:
|
||||
return self.cog.data
|
||||
|
||||
@property
|
||||
def communityid(self):
|
||||
return self.row.communityid
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Community communityid={self.communityid} row={self.row}>"
|
||||
|
||||
async def attach_discord(self, guild: discord.Guild):
|
||||
"""
|
||||
Attach a new discord guild to this community.
|
||||
Assumes the discord guild is not already associated to a community.
|
||||
"""
|
||||
discord_row = await self.data.DiscordCommunityRow.create(
|
||||
communityid=self.communityid,
|
||||
guildid=guild.id
|
||||
)
|
||||
logger.info(
|
||||
f"Attached discord guild {guild!r} to community {self!r}"
|
||||
)
|
||||
return discord_row
|
||||
|
||||
async def attach_twitch(self, user: twitchio.User):
|
||||
"""
|
||||
Attach a new Twitch user channel to this community.
|
||||
"""
|
||||
twitch_row = await self.data.TwitchCommunityRow.create(
|
||||
communityid=self.communityid,
|
||||
channelid=str(user.id)
|
||||
)
|
||||
logger.info(
|
||||
f"Attached twitch channel {user!r} to community {self!r}"
|
||||
)
|
||||
return twitch_row
|
||||
|
||||
async def discord_guilds(self) -> list[ProfileData.DiscordCommunityRow]:
|
||||
"""
|
||||
Fetch the Discord guild rows associated to this community.
|
||||
"""
|
||||
return await self.data.DiscordCommunityRow.fetch_where(communityid=self.communityid)
|
||||
|
||||
async def twitch_channels(self) -> list[ProfileData.TwitchCommunityRow]:
|
||||
"""
|
||||
Fetch the Twitch user rows associated to this profile.
|
||||
"""
|
||||
return await self.data.TwitchCommunityRow.fetch_where(communityid=self.communityid)
|
||||
|
||||
@classmethod
|
||||
async def fetch(cls, bot: LionBot, community_id: int) -> Self:
|
||||
community_row = await bot.get_cog('ProfileCog').data.CommunityRow.fetch(community_id)
|
||||
if community_row is None:
|
||||
raise ValueError("Provided community_id does not exist.")
|
||||
return cls(bot, community_row)
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_twitchid(cls, bot: LionBot, channelid: int | str) -> Optional[Self]:
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
rows = await data.TwitchCommunityRow.fetch_where(channelid=str(channelid))
|
||||
if rows:
|
||||
return await cls.fetch(bot, rows[0].communityid)
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_discordid(cls, bot: LionBot, guildid: int) -> Optional[Self]:
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
rows = await data.DiscordCommunityRow.fetch_where(guildid=guildid)
|
||||
if rows:
|
||||
return await cls.fetch(bot, rows[0].communityid)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, bot: LionBot, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new empty community with the given initial arguments.
|
||||
|
||||
Communities should usually be created using `create_from_discord` or `create_from_twitch`
|
||||
to correctly setup initial preferences (e.g. name, avatar).
|
||||
"""
|
||||
# Create a new community
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
row = await data.CommunityRow.create(created_at=utc_now(), **kwargs)
|
||||
return await cls.fetch(bot, row.communityid)
|
||||
|
||||
@classmethod
|
||||
async def create_from_discord(cls, bot: LionBot, guild: discord.Guild, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new community using the given Discord guild as a base.
|
||||
"""
|
||||
self = await cls.create(bot, **kwargs)
|
||||
await self.attach_discord(guild)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def create_from_twitch(cls, bot: LionBot, user: twitchio.User, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new profile using the given Twitch channel user as a base.
|
||||
"""
|
||||
self = await cls.create(bot, **kwargs)
|
||||
await self.attach_twitch(user)
|
||||
return self
|
||||
158
src/modules/profiles/data.py
Normal file
158
src/modules/profiles/data.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from data import Registry, RowModel
|
||||
from data.columns import Integer, String, Timestamp
|
||||
|
||||
|
||||
class ProfileData(Registry):
|
||||
class UserProfileRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE user_profiles(
|
||||
profileid SERIAL PRIMARY KEY,
|
||||
nickname TEXT,
|
||||
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'user_profiles'
|
||||
_cache_ = {}
|
||||
|
||||
profileid = Integer(primary=True)
|
||||
nickname = String()
|
||||
migrated = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
|
||||
class DiscordProfileRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE profiles_discord(
|
||||
linkid SERIAL PRIMARY KEY,
|
||||
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
userid BIGINT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX profiles_discord_profileid ON profiles_discord (profileid);
|
||||
CREATE UNIQUE INDEX profiles_discord_userid ON profiles_discord (userid);
|
||||
"""
|
||||
_tablename_ = 'profiles_discord'
|
||||
_cache_ = {}
|
||||
|
||||
linkid = Integer(primary=True)
|
||||
profileid = Integer()
|
||||
userid = Integer()
|
||||
created_at = Integer()
|
||||
|
||||
@classmethod
|
||||
async def fetch_profile(cls, profileid: int):
|
||||
rows = await cls.fetch_where(profiled=profileid)
|
||||
return next(rows, None)
|
||||
|
||||
|
||||
class TwitchProfileRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE profiles_twitch(
|
||||
linkid SERIAL PRIMARY KEY,
|
||||
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
userid TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX profiles_twitch_profileid ON profiles_twitch (profileid);
|
||||
CREATE UNIQUE INDEX profiles_twitch_userid ON profiles_twitch (userid);
|
||||
"""
|
||||
_tablename_ = 'profiles_twitch'
|
||||
_cache_ = {}
|
||||
|
||||
linkid = Integer(primary=True)
|
||||
profileid = Integer()
|
||||
userid = String()
|
||||
created_at = Timestamp()
|
||||
|
||||
@classmethod
|
||||
async def fetch_profile(cls, profileid: int):
|
||||
rows = await cls.fetch_where(profiled=profileid)
|
||||
return next(rows, None)
|
||||
|
||||
class CommunityRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE communities(
|
||||
communityid SERIAL PRIMARY KEY,
|
||||
migrated INTEGER REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'communities'
|
||||
_cache_ = {}
|
||||
|
||||
communityid = Integer(primary=True)
|
||||
migrated = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
class DiscordCommunityRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE communities_discord(
|
||||
guildid BIGINT PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'communities_discord'
|
||||
_cache_ = {}
|
||||
|
||||
guildid = Integer(primary=True)
|
||||
communityid = Integer()
|
||||
linked_at = Timestamp()
|
||||
|
||||
@classmethod
|
||||
async def fetch_community(cls, communityid: int):
|
||||
rows = await cls.fetch_where(communityd=communityid)
|
||||
return next(rows, None)
|
||||
|
||||
class TwitchCommunityRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE communities_twitch(
|
||||
channelid TEXT PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
linked_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'communities_twitch'
|
||||
_cache_ = {}
|
||||
|
||||
channelid = String(primary=True)
|
||||
communityid = Integer()
|
||||
linked_at = Timestamp()
|
||||
|
||||
@classmethod
|
||||
async def fetch_community(cls, communityid: int):
|
||||
rows = await cls.fetch_where(communityd=communityid)
|
||||
return next(rows, None)
|
||||
|
||||
class CommunityMemberRow(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE community_members(
|
||||
memberid SERIAL PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities (communityid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
profileid INTEGER NOT NULL REFERENCES user_profiles (profileid) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE UNIQUE INDEX community_members_communityid_profileid ON community_members (communityid, profileid);
|
||||
"""
|
||||
_tablename_ = 'community_members'
|
||||
_cache_ = {}
|
||||
|
||||
memberid = Integer(primary=True)
|
||||
communityid = Integer()
|
||||
profileid = Integer()
|
||||
created_at = Timestamp()
|
||||
156
src/modules/profiles/profile.py
Normal file
156
src/modules/profiles/profile.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from typing import Optional, Self
|
||||
|
||||
import discord
|
||||
import twitchio
|
||||
|
||||
from meta import LionBot
|
||||
from utils.lib import utc_now
|
||||
|
||||
from . import logger
|
||||
from .data import ProfileData
|
||||
|
||||
|
||||
|
||||
class UserProfile:
|
||||
def __init__(self, bot: LionBot, profile_row):
|
||||
self.bot = bot
|
||||
self.profile_row: ProfileData.UserProfileRow = profile_row
|
||||
|
||||
@property
|
||||
def cog(self):
|
||||
return self.bot.get_cog('ProfileCog')
|
||||
|
||||
@property
|
||||
def data(self) -> ProfileData:
|
||||
return self.cog.data
|
||||
|
||||
@property
|
||||
def profileid(self):
|
||||
return self.profile_row.profileid
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UserProfile profileid={self.profileid} profile={self.profile_row}>"
|
||||
|
||||
async def get_name(self):
|
||||
# TODO: Store a preferred name in the profile preferences
|
||||
# TODO Should have a multi-fetch system
|
||||
name = None
|
||||
twitches = await self.twitch_accounts()
|
||||
if twitches:
|
||||
users = await self.bot.crocbot.fetch_users(
|
||||
ids=[int(twitches[0].userid)]
|
||||
)
|
||||
if users:
|
||||
user = users[0]
|
||||
name = user.display_name
|
||||
|
||||
if not name:
|
||||
discords = await self.discord_accounts()
|
||||
if discords:
|
||||
user = await self.bot.fetch_user(discords[0].userid)
|
||||
name = user.display_name
|
||||
|
||||
if not name:
|
||||
name = 'Unknown'
|
||||
|
||||
return name
|
||||
|
||||
async def attach_discord(self, user: discord.User | discord.Member):
|
||||
"""
|
||||
Attach a new discord user to this profile.
|
||||
Assumes the discord user does not itself have a profile.
|
||||
"""
|
||||
discord_row = await self.data.DiscordProfileRow.create(
|
||||
profileid=self.profileid,
|
||||
userid=user.id
|
||||
)
|
||||
logger.info(
|
||||
f"Attached discord user {user!r} to profile {self!r}"
|
||||
)
|
||||
return discord_row
|
||||
|
||||
async def attach_twitch(self, user: twitchio.User):
|
||||
"""
|
||||
Attach a new Twitch user to this profile.
|
||||
"""
|
||||
twitch_row = await self.data.TwitchProfileRow.create(
|
||||
profileid=self.profileid,
|
||||
userid=str(user.id)
|
||||
)
|
||||
logger.info(
|
||||
f"Attached twitch user {user!r} to profile {self!r}"
|
||||
)
|
||||
return twitch_row
|
||||
|
||||
async def discord_accounts(self) -> list[ProfileData.DiscordProfileRow]:
|
||||
"""
|
||||
Fetch the Discord accounts associated to this profile.
|
||||
"""
|
||||
return await self.data.DiscordProfileRow.fetch_where(
|
||||
profileid=self.profileid
|
||||
).order_by(
|
||||
'created_at'
|
||||
)
|
||||
|
||||
async def twitch_accounts(self) -> list[ProfileData.TwitchProfileRow]:
|
||||
"""
|
||||
Fetch the Twitch accounts associated to this profile.
|
||||
"""
|
||||
return await self.data.TwitchProfileRow.fetch_where(
|
||||
profileid=self.profileid
|
||||
).order_by(
|
||||
'created_at'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def fetch(cls, bot: LionBot, profile_id: int) -> Self:
|
||||
profile_row = await bot.get_cog('ProfileCog').data.UserProfileRow.fetch(profile_id)
|
||||
if profile_row is None:
|
||||
raise ValueError("Provided profile_id does not exist.")
|
||||
return cls(bot, profile_row)
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_twitchid(cls, bot: LionBot, userid: int | str) -> Optional[Self]:
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
rows = await data.TwitchProfileRow.fetch_where(userid=str(userid))
|
||||
if rows:
|
||||
return await cls.fetch(bot, rows[0].profileid)
|
||||
|
||||
@classmethod
|
||||
async def fetch_from_discordid(cls, bot: LionBot, userid: int) -> Optional[Self]:
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
rows = await data.DiscordProfileRow.fetch_where(userid=(userid))
|
||||
if rows:
|
||||
return await cls.fetch(bot, rows[0].profileid)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, bot: LionBot, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new empty profile with the given initial arguments.
|
||||
|
||||
Profiles should usually be created using `create_from_discord` or `create_from_twitch`
|
||||
to correctly setup initial profile preferences (e.g. name, avatar).
|
||||
"""
|
||||
# Create a new profile
|
||||
data = bot.get_cog('ProfileCog').data
|
||||
profile_row = await data.UserProfileRow.create(created_at=utc_now())
|
||||
profile = await cls.fetch(bot, profile_row.profileid)
|
||||
return profile
|
||||
|
||||
@classmethod
|
||||
async def create_from_discord(cls, bot: LionBot, user: discord.Member | discord.User, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new profile using the given Discord user as a base.
|
||||
"""
|
||||
profile = await cls.create(bot, **kwargs)
|
||||
await profile.attach_discord(user)
|
||||
return profile
|
||||
|
||||
@classmethod
|
||||
async def create_from_twitch(cls, bot: LionBot, user: twitchio.User, **kwargs) -> Self:
|
||||
"""
|
||||
Create a new profile using the given Twitch user as a base.
|
||||
"""
|
||||
profile = await cls.create(bot, **kwargs)
|
||||
await profile.attach_twitch(user)
|
||||
return profile
|
||||
@@ -17,9 +17,19 @@ class ShoutoutCog(LionCog):
|
||||
and drop a follow! \
|
||||
They {areorwere} streaming {game} at {channel}
|
||||
"""
|
||||
COWO_SHOUTOUT = """
|
||||
We think that {name} is a great coworker and you should check them out for more productive vibes! \
|
||||
They {areorwere} streaming {game} at {channel}
|
||||
"""
|
||||
ART_SHOUTOUT = """
|
||||
We think that {name} is an awesome artist and you should check them out for cool art and cosy vibes! \
|
||||
They {areorwere} streaming {game} at {channel}
|
||||
"""
|
||||
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.crocbot = bot.crocbot
|
||||
self.crocbot: CrocBot = bot.crocbot
|
||||
|
||||
self.data = bot.db.load_registry(ShoutoutData())
|
||||
|
||||
self.loaded = asyncio.Event()
|
||||
@@ -59,19 +69,28 @@ class ShoutoutCog(LionCog):
|
||||
return replace_multiple(text, mapping)
|
||||
|
||||
@commands.command(aliases=['so'])
|
||||
async def shoutout(self, ctx: commands.Context, user: twitchio.User):
|
||||
async def shoutout(self, ctx: commands.Context, target: str, typ: Optional[str]=None):
|
||||
# Make sure caller is mod/broadcaster
|
||||
# Lookup custom shoutout for this user
|
||||
# If it exists use it, otherwise use default shoutout
|
||||
if (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||
data = await self.data.CustomShoutout.fetch(int(user.id))
|
||||
if data:
|
||||
shoutout = data.content
|
||||
user = await self.crocbot.seek_user(target)
|
||||
if user is None:
|
||||
await ctx.reply(f"Couldn't resolve '{target}' to a valid user.")
|
||||
else:
|
||||
shoutout = self.DEFAULT_SHOUTOUT
|
||||
formatted = await self.format_shoutout(shoutout, user)
|
||||
await ctx.reply(formatted)
|
||||
data = await self.data.CustomShoutout.fetch(int(user.id))
|
||||
if data:
|
||||
shoutout = data.content
|
||||
elif typ == 'cowo':
|
||||
shoutout = self.COWO_SHOUTOUT
|
||||
elif typ == 'art':
|
||||
shoutout = self.ART_SHOUTOUT
|
||||
else:
|
||||
shoutout = self.DEFAULT_SHOUTOUT
|
||||
formatted = await self.format_shoutout(shoutout, user)
|
||||
await ctx.reply(formatted)
|
||||
# TODO: How to /shoutout with lib?
|
||||
# TODO Shoutout queue
|
||||
|
||||
@commands.command()
|
||||
async def editshoutout(self, ctx: commands.Context, user: twitchio.User, *, text: str):
|
||||
|
||||
@@ -8,9 +8,16 @@ async def get_leaderboard_card(
|
||||
bot: LionBot, highlightid: int, guildid: int,
|
||||
mode: CardMode,
|
||||
entry_data: list[tuple[int, int, int]], # userid, position, time
|
||||
name_map: dict[int, str] = {},
|
||||
extra_skin_args = {},
|
||||
):
|
||||
"""
|
||||
Render a leaderboard card with given parameters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name_map: dict[int, str]
|
||||
Map of userid -> name, used first before cache or fetch.
|
||||
"""
|
||||
guild = bot.get_guild(guildid)
|
||||
if guild is None:
|
||||
@@ -20,8 +27,12 @@ async def get_leaderboard_card(
|
||||
avatars = {}
|
||||
names = {}
|
||||
missing = []
|
||||
|
||||
for userid, _, _ in entry_data:
|
||||
if guild and (member := guild.get_member(userid)):
|
||||
if (name := name_map.get(userid, None)):
|
||||
avatars[userid] = None
|
||||
names[userid] = name
|
||||
elif guild and (member := guild.get_member(userid)):
|
||||
avatars[userid] = member.avatar.key if member.avatar else None
|
||||
names[userid] = member.display_name
|
||||
elif (user := bot.get_user(userid)):
|
||||
@@ -65,7 +76,7 @@ async def get_leaderboard_card(
|
||||
guildid, None, LeaderboardCard.card_id
|
||||
)
|
||||
card = LeaderboardCard(
|
||||
skin=skin | {'mode': mode},
|
||||
skin=skin | {'mode': mode} | extra_skin_args,
|
||||
server_name=guild.name,
|
||||
entries=entries,
|
||||
highlight=highlight
|
||||
|
||||
8
src/modules/time/__init__.py
Normal file
8
src/modules/time/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import TimeCog
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(TimeCog(bot))
|
||||
110
src/modules/time/cog.py
Normal file
110
src/modules/time/cog.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import datetime as dt
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
|
||||
from meta import CrocBot, LionCog, LionContext, LionBot
|
||||
from utils.lib import strfdelta, utc_now, parse_dur
|
||||
|
||||
from . import logger
|
||||
|
||||
|
||||
class TimeCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.crocbot: CrocBot = bot.crocbot
|
||||
|
||||
async def cog_load(self):
|
||||
self._load_twitch_methods(self.crocbot)
|
||||
|
||||
async def cog_unload(self):
|
||||
self._unload_twitch_methods(self.crocbot)
|
||||
|
||||
async def get_timezone_for(self, profile):
|
||||
timezone = None
|
||||
discords = await profile.discord_accounts()
|
||||
if discords:
|
||||
userid = discords[0].userid
|
||||
luser = await self.bot.core.lions.fetch_user(userid)
|
||||
if luser:
|
||||
timezone = luser.config.timezone.value
|
||||
return timezone
|
||||
|
||||
def get_timestr(self, tz, brief=False):
|
||||
"""
|
||||
Get the current time in the given timezone, using a fixed format string.
|
||||
"""
|
||||
format_str = "%H:%M, %d/%m/%Y" if brief else "%I:%M %p (%Z) on %a, %d/%m/%Y"
|
||||
now = dt.datetime.now(tz=tz)
|
||||
return now.strftime(format_str)
|
||||
|
||||
async def time_diff(self, tz, auth_tz, name, brief=False):
|
||||
"""
|
||||
Get a string representing the time difference between the user's timezone and the given one.
|
||||
"""
|
||||
if auth_tz is None or tz is None:
|
||||
return None
|
||||
author_time = dt.datetime.now(tz=auth_tz)
|
||||
other_time = dt.datetime.now(tz=tz)
|
||||
timediff = other_time.replace(tzinfo=None) - author_time.replace(tzinfo=None)
|
||||
diffsecs = round(timediff.total_seconds())
|
||||
|
||||
if diffsecs == 0:
|
||||
return ", the same as {}!".format(name)
|
||||
|
||||
modifier = "behind" if diffsecs > 0 else "ahead"
|
||||
diffsecs = abs(diffsecs)
|
||||
|
||||
hours, remainder = divmod(diffsecs, 3600)
|
||||
mins, _ = divmod(remainder, 60)
|
||||
|
||||
hourstr = "{} hour{} ".format(hours, "s" if hours > 1 else "") if hours else ""
|
||||
minstr = "{} minutes ".format(mins) if mins else ""
|
||||
joiner = "and " if (hourstr and minstr) else ""
|
||||
return ". {} is {}{}{}{}, at {}.".format(
|
||||
name, hourstr, joiner, minstr, modifier, self.get_timestr(auth_tz, brief=brief)
|
||||
)
|
||||
|
||||
@commands.command(name='time', aliases=['ti'])
|
||||
async def time_cmd(self, ctx, *, args: str=''):
|
||||
"""
|
||||
Current usage is
|
||||
!time
|
||||
!time <target user>
|
||||
|
||||
Planned:
|
||||
!time set ...
|
||||
!time at ...
|
||||
"""
|
||||
authprofile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||
authtz = await self.get_timezone_for(authprofile)
|
||||
|
||||
if args:
|
||||
target_tw = await self.crocbot.seek_user(args)
|
||||
if target_tw is None:
|
||||
return await ctx.reply(f"Couldn't find user '{args}'!")
|
||||
target = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(target_tw)
|
||||
targettz = await self.get_timezone_for(target)
|
||||
name = await target.get_name()
|
||||
if targettz is None:
|
||||
return await ctx.reply(
|
||||
f"{name} hasn't set their timezone! Ask them to set it with '/my timezone' on discord."
|
||||
)
|
||||
else:
|
||||
target = None
|
||||
targettz = None
|
||||
name = None
|
||||
if authtz is None:
|
||||
return await ctx.reply(
|
||||
"You haven't set your timezone! Set it on discord by linking your Twitch account with `/profiles link twitch`, and then using `/my timezone`"
|
||||
)
|
||||
|
||||
timestr = self.get_timestr(targettz if target else authtz)
|
||||
name = name or await authprofile.get_name()
|
||||
|
||||
if target:
|
||||
tdiffstr = await self.time_diff(targettz, authtz, await authprofile.get_name())
|
||||
msg = f"The current time for {name} is {timestr}{tdiffstr}"
|
||||
else:
|
||||
msg = f"The current time for {name} is {timestr}"
|
||||
await ctx.reply(msg)
|
||||
8
src/modules/twreminders/__init__.py
Normal file
8
src/modules/twreminders/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import ReminderCog
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(ReminderCog(bot))
|
||||
369
src/modules/twreminders/cog.py
Normal file
369
src/modules/twreminders/cog.py
Normal file
@@ -0,0 +1,369 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import itertools
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
|
||||
from dateutil.parser import ParserError, parse
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
import datetime as dt
|
||||
from datetime import timedelta, datetime
|
||||
|
||||
from meta import CrocBot, LionCog, LionContext, LionBot
|
||||
from utils.lib import strfdelta, utc_now, parse_dur
|
||||
from . import logger
|
||||
|
||||
|
||||
reminder_regex = re.compile(
|
||||
r"""
|
||||
(^)?(?P<type> (?: \b in) | (?: every) | (?P<at> at))
|
||||
\s*
|
||||
(?(at) (?P<time> \d?\d (?: :\d\d)?\s*(?: am | pm)?) | (?P<duration> (?: day| hour| (?:\d+\s*(?:(?:d|h|m|s)[a-zA-Z]*)?(?:\s|and)*)+)))
|
||||
(?:(?(1) (?:, | ; | : | \. | to)?\s+ | $ ))
|
||||
""",
|
||||
re.IGNORECASE | re.VERBOSE | re.DOTALL
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Reminder:
|
||||
userid: int
|
||||
content: str
|
||||
name: str
|
||||
channel: str
|
||||
remind_at: datetime
|
||||
|
||||
|
||||
class ReminderCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.crocbot: CrocBot = bot.crocbot
|
||||
|
||||
self.loaded = asyncio.Event()
|
||||
self.reminders: dict[int, list[Reminder]] = defaultdict(list)
|
||||
|
||||
self.next_reminder_task = None
|
||||
self._reminder_wait_task = None
|
||||
self.reminder_lock = asyncio.Lock()
|
||||
|
||||
async def cog_load(self):
|
||||
await self.load_reminders()
|
||||
self._load_twitch_methods(self.crocbot)
|
||||
self.loaded.set()
|
||||
|
||||
async def ensure_loaded(self):
|
||||
if not self.loaded.is_set():
|
||||
await self.cog_load()
|
||||
|
||||
async def cog_unload(self):
|
||||
self._unload_twitch_methods(self.crocbot)
|
||||
|
||||
async def cog_check(self, ctx):
|
||||
await self.ensure_loaded()
|
||||
return True
|
||||
|
||||
def save_reminders(self):
|
||||
with open('reminders.json', 'w', encoding='utf-8') as f:
|
||||
mapped = {
|
||||
int(userid): [
|
||||
{
|
||||
'userid': int(state.userid),
|
||||
'name': state.name,
|
||||
'channel': state.channel,
|
||||
'content': state.content,
|
||||
'remind_at': state.remind_at.isoformat(),
|
||||
}
|
||||
for state in states
|
||||
]
|
||||
for userid, states in self.reminders.items()
|
||||
}
|
||||
json.dump(mapped, f, ensure_ascii=False, indent=4)
|
||||
|
||||
async def load_reminders(self):
|
||||
if self.next_reminder_task and not self.next_reminder_task.cancelled():
|
||||
self.next_reminder_task.cancel()
|
||||
self.next_reminder_task = None
|
||||
|
||||
with open('reminders.json') as f:
|
||||
mapped = json.load(f)
|
||||
self.reminders.clear()
|
||||
for userid, states in mapped.items():
|
||||
userid = int(userid)
|
||||
for map in states:
|
||||
reminder = Reminder(
|
||||
userid=int(map['userid']),
|
||||
content=map['content'],
|
||||
name=map['name'],
|
||||
channel=map['channel'],
|
||||
remind_at=dt.datetime.fromisoformat(map['remind_at'])
|
||||
)
|
||||
self.reminders[userid].append(reminder)
|
||||
self.schedule_next_reminder()
|
||||
logger.info(f"Loaded reminders: {self.reminders}")
|
||||
|
||||
def schedule_next_reminder(self):
|
||||
"""
|
||||
Schedule the next reminder in the queue, if it exists, and return it.
|
||||
Cancels any currently running task.
|
||||
"""
|
||||
if not self.reminders:
|
||||
return None
|
||||
next_reminder = min(
|
||||
itertools.chain(*self.reminders.values()), key=lambda r: r.remind_at, default=None
|
||||
)
|
||||
if next_reminder:
|
||||
self.next_reminder_task = asyncio.create_task(self.run_reminder(next_reminder))
|
||||
else:
|
||||
# We still need to cancel any ongoing reminders
|
||||
if self._reminder_wait_task and not self._reminder_wait_task.cancelled():
|
||||
self._reminder_wait_task.cancel()
|
||||
|
||||
async def run_reminder(self, reminder: Reminder):
|
||||
"""
|
||||
Wait for and then run the given reminder.
|
||||
Expects to be cancelled if another reminder is scheduled earlier.
|
||||
"""
|
||||
# Cancel the next reminder wait task.
|
||||
# If the next reminder is currently executing/firing,
|
||||
# this will do nothing and we will wait until it is finished.
|
||||
if self._reminder_wait_task and not self._reminder_wait_task.cancelled():
|
||||
self._reminder_wait_task.cancel()
|
||||
|
||||
# This ensures that only one reminder task runs at once
|
||||
async with self.reminder_lock:
|
||||
now = utc_now()
|
||||
to_wait = (reminder.remind_at - now).total_seconds()
|
||||
try:
|
||||
self._reminder_wait_task = asyncio.create_task(asyncio.sleep(to_wait))
|
||||
await self._reminder_wait_task
|
||||
except asyncio.CancelledError:
|
||||
# Reminder task was cancelled
|
||||
raise
|
||||
|
||||
# Now fire the reminder
|
||||
await self.fire_reminder(reminder)
|
||||
|
||||
# And schedule the next reminder if needed
|
||||
self.schedule_next_reminder()
|
||||
|
||||
async def fire_reminder(self, reminder: Reminder):
|
||||
"""
|
||||
Actually run the given reminder.
|
||||
"""
|
||||
# Check that this reminder is still valid
|
||||
if reminder not in self.reminders[reminder.userid]:
|
||||
logger.error(f"Reminder {reminder!r} is firing but not scheduled!")
|
||||
return
|
||||
|
||||
# We don't want to reschedule while a reminder is running
|
||||
# Get the channel to send to
|
||||
destination = self.crocbot.get_channel(reminder.channel)
|
||||
if destination is None:
|
||||
logger.info(f"Reminder couldn't get channel '{reminder.channel}'. Trying again in a minute.")
|
||||
# In case we aren't actually ready yet
|
||||
await self.crocbot.wait_for_ready()
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Cancelling channel wait task for reminder.")
|
||||
raise
|
||||
destination = self.crocbot.get_channel(reminder.channel)
|
||||
if destination is None:
|
||||
# This means we haven't joined the channel
|
||||
logger.warning(f"Reminder couldn't get channel '{reminder.channel}' for the second time. Cancelling.")
|
||||
else:
|
||||
logger.info(f"Channel '{reminder.channel}' found as {destination}. Continuing.")
|
||||
|
||||
if destination is not None:
|
||||
# Send the reminder
|
||||
msg = f"@{reminder.name}, you asked me to remind you: {reminder.content}"
|
||||
await destination.send(msg)
|
||||
|
||||
# This should really be based on a reminderid but oh well
|
||||
# It's theoretically possible for a reminder to be scheduled at the same time as it is run
|
||||
# In which case the wrong reminder will be removed.
|
||||
self.reminders[reminder.userid].remove(reminder)
|
||||
self.save_reminders()
|
||||
|
||||
def get_reminders_for(self, userid: int):
|
||||
return self.reminders.get(userid, [])
|
||||
|
||||
@commands.command(name='remindme', aliases=['reminders', 'reminder'])
|
||||
async def remindme_cmd(self, ctx, *, args: str=''):
|
||||
args = args.strip()
|
||||
userid = int(ctx.author.id)
|
||||
existing = self.get_reminders_for(userid)
|
||||
existing.sort(key=lambda r: r.remind_at, reverse=False)
|
||||
now = utc_now()
|
||||
|
||||
if not args or args.lower() in ('show', 'list'):
|
||||
# Show user's current reminders or show usage
|
||||
if not existing:
|
||||
await ctx.reply(
|
||||
"USAGE: !remindme <task> in <dur> EG: !remindme Coffee is ready in 10m | !remindme in 10m, Coffee is ready"
|
||||
)
|
||||
elif len(existing) == 1:
|
||||
reminder = existing[0]
|
||||
dur = reminder.remind_at - now
|
||||
sec = (dur.total_seconds()) < 60
|
||||
formatted_dur = strfdelta(dur, short=False, sec=sec)
|
||||
await ctx.reply(
|
||||
f"I will remind you about '{reminder.content}' in about {formatted_dur}. Use !remindme cancel to cancel!"
|
||||
)
|
||||
else:
|
||||
parts = []
|
||||
for i, reminder in enumerate(existing, start=1):
|
||||
dur = reminder.remind_at - now
|
||||
sec = (dur.total_seconds()) < 60
|
||||
formatted_dur = strfdelta(dur, short=True, sec=sec)
|
||||
parts.append(
|
||||
f"{i}: '{reminder.content}' in {formatted_dur}"
|
||||
)
|
||||
remstr = '; '.join(parts)
|
||||
if len(remstr) > 290:
|
||||
remstr = remstr[:290] + '...'
|
||||
|
||||
await ctx.reply(
|
||||
f"Active Reminders: {remstr}. Use '!remindme cancel n' or '!remindme clear' to remove!"
|
||||
)
|
||||
elif args.lower() in ('clear', 'clearall', 'remove all'):
|
||||
# Remove all reminders
|
||||
if existing:
|
||||
self.reminders.pop(userid, None)
|
||||
self.save_reminders()
|
||||
self.schedule_next_reminder()
|
||||
else:
|
||||
await ctx.reply("You don't have any reminders set!")
|
||||
elif args.lower().split(maxsplit=1)[0] in ('remove', 'cancel'):
|
||||
splits = args.split(maxsplit=1)
|
||||
remaining = splits[1].strip() if len(splits) > 1 else ''
|
||||
|
||||
# Remove a specified reminder
|
||||
to_remove = None
|
||||
if not existing:
|
||||
await ctx.reply("You don't have any reminders set!")
|
||||
elif len(existing) == 1:
|
||||
to_remove = existing[0]
|
||||
elif remaining.isdigit():
|
||||
# Try to the remove the reminder with the give number
|
||||
given = int(remaining)
|
||||
if given > len(existing):
|
||||
await ctx.reply(f"You only have {len(existing)} reminders!")
|
||||
else:
|
||||
to_remove = existing[given - 1]
|
||||
else:
|
||||
# Invalid arguments, show usage
|
||||
await ctx.reply(
|
||||
"USAGE: !remindme cancel <number>, e.g. !remindme cancel 1 to cancel your first reminder!"
|
||||
)
|
||||
|
||||
if to_remove is not None:
|
||||
self.reminders[userid].remove(to_remove)
|
||||
await ctx.reply(
|
||||
f"Cancelled your reminder '{to_remove.content}'"
|
||||
)
|
||||
self.save_reminders()
|
||||
self.schedule_next_reminder()
|
||||
else:
|
||||
# Parse for reminder
|
||||
content = None
|
||||
duration = None
|
||||
repeating = None
|
||||
|
||||
# First parse it
|
||||
match = re.search(reminder_regex, args)
|
||||
if match:
|
||||
typ = match.group('type').lower().strip()
|
||||
content = (args[:match.start()] + args[match.end():]).strip()
|
||||
if typ in ('every', 'in'):
|
||||
repeating = typ == 'every'
|
||||
duration_str = match.group('duration').lower()
|
||||
if duration_str.isdigit():
|
||||
# Default to minutes if no unit given
|
||||
duration = int(duration_str) * 60
|
||||
elif duration_str in ('day', 'a day'):
|
||||
duration = 24 * 60 * 60
|
||||
elif duration_str in ('hour', 'an hour'):
|
||||
duration = 60 * 60
|
||||
else:
|
||||
duration = parse_dur(duration_str)
|
||||
|
||||
elif typ == 'at':
|
||||
# Get timezone for this member.
|
||||
profile = await self.bot.get_cog('ProfileCog').fetch_profile_twitch(ctx.author)
|
||||
timezone = None
|
||||
discords = await profile.discord_accounts()
|
||||
if discords:
|
||||
luserid = discords[0].userid
|
||||
luser = await self.bot.core.lions.fetch_user(luserid)
|
||||
if luser:
|
||||
timezone = luser.config.timezone.value
|
||||
if not timezone:
|
||||
return await ctx.reply(
|
||||
"Sorry, to use this you have to link your account with `/profiles link twitch` and set your timezone with '/my timezone' on the Discord!"
|
||||
)
|
||||
|
||||
time_str = match.group('time').lower()
|
||||
if time_str.isdigit():
|
||||
# Assume it's an hour
|
||||
time_str = time_str + ':00'
|
||||
default = dt.datetime.now(tz=timezone).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
try:
|
||||
ts = parse(time_str, fuzzy=True, default=default)
|
||||
except ParserError:
|
||||
return await ctx.reply(
|
||||
"Sorry, I didn't understand your target time! Please use e.g. !remindme Remember to hydrate at 10pm"
|
||||
)
|
||||
while ts < dt.datetime.now(tz=timezone):
|
||||
ts += dt.timedelta(days=1)
|
||||
|
||||
duration = (ts - dt.datetime.now(tz=timezone)).total_seconds()
|
||||
duration = int(duration)
|
||||
|
||||
if content.startswith('to '):
|
||||
content = content[3:].strip()
|
||||
else:
|
||||
# Legacy parsing, without requiring "in" at the front
|
||||
splits = args.split(maxsplit=1)
|
||||
if len(splits) == 2 and splits[0].isdigit():
|
||||
repeating = False
|
||||
duration = int(splits[0]) * 60
|
||||
content = splits[1].strip()
|
||||
|
||||
# Sanity checking
|
||||
if not duration or not content:
|
||||
return await ctx.reply(
|
||||
"Sorry, I didn't understand your reminder! Please use e.g. !remindme Coffee is ready in 10m"
|
||||
)
|
||||
if repeating:
|
||||
return await ctx.reply(
|
||||
"Sorry, we don't support repeating reminders right now!"
|
||||
)
|
||||
if len(existing) > 10:
|
||||
return await ctx.reply(
|
||||
"Sorry, you can only have 10 active reminders! Use !remindme cancel or !remindme clear to cancel some!"
|
||||
)
|
||||
|
||||
reminder = Reminder(
|
||||
userid=userid,
|
||||
content=content,
|
||||
name=ctx.author.name,
|
||||
channel=ctx.channel.name,
|
||||
remind_at=now + timedelta(seconds=duration)
|
||||
)
|
||||
|
||||
self.reminders[userid].append(reminder)
|
||||
dur = reminder.remind_at - now
|
||||
sec = (dur.total_seconds()) < 60
|
||||
formatted_dur = strfdelta(dur, short=False, sec=sec)
|
||||
|
||||
msg = f"Got it! I will remind you in {formatted_dur}!"
|
||||
|
||||
await ctx.reply(msg)
|
||||
|
||||
self.save_reminders()
|
||||
self.schedule_next_reminder()
|
||||
Submodule src/modules/voicefix updated: 5146e46515...cca1c94bd5
7
src/modules/voiceroles/__init__.py
Normal file
7
src/modules/voiceroles/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def setup(bot):
|
||||
from .cog import VoiceRoleCog
|
||||
await bot.add_cog(VoiceRoleCog(bot))
|
||||
166
src/modules/voiceroles/cog.py
Normal file
166
src/modules/voiceroles/cog.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
from cachetools import FIFOCache
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import discord
|
||||
from discord.abc import GuildChannel
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
from meta import LionBot, LionCog, LionContext
|
||||
from meta.logger import log_wrap
|
||||
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
|
||||
from utils.ui import Confirm
|
||||
|
||||
from . import logger
|
||||
from .data import VoiceRoleData
|
||||
|
||||
|
||||
class VoiceRoleCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = bot.db.load_registry(VoiceRoleData())
|
||||
|
||||
self._event_locks: WeakValueDictionary[tuple[int, int], asyncio.Lock] = WeakValueDictionary()
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
@LionCog.listener('on_voice_state_update')
|
||||
@log_wrap(action='Voice Role Update')
|
||||
async def voicerole_update(self, member: discord.Member,
|
||||
before: discord.VoiceState, after: discord.VoiceState):
|
||||
if member.bot:
|
||||
return
|
||||
|
||||
after_channel = after.channel
|
||||
before_channel = before.channel
|
||||
if after_channel == before_channel:
|
||||
return
|
||||
|
||||
task_key = (member.guild.id, member.id)
|
||||
async with self.event_lock(task_key):
|
||||
# Get the roles of the channel they left to remove
|
||||
# Get the roles of the channel they are joining to add
|
||||
# Use a set difference to remove the roles to be added from the ones to remove
|
||||
if before_channel is not None:
|
||||
leaving_roles = await self.get_roles_for(before_channel.id)
|
||||
else:
|
||||
leaving_roles = []
|
||||
|
||||
if after_channel is not None:
|
||||
gaining_roles = await self.get_roles_for(after_channel.id)
|
||||
else:
|
||||
gaining_roles = []
|
||||
|
||||
to_remove = []
|
||||
for role in leaving_roles:
|
||||
if role in member.roles and role not in gaining_roles and role.is_assignable():
|
||||
to_remove.append(role)
|
||||
|
||||
to_add = []
|
||||
for role in gaining_roles:
|
||||
if role not in member.roles and role.is_assignable():
|
||||
to_add.append(role)
|
||||
|
||||
if to_remove:
|
||||
await member.remove_roles(*to_remove, reason="Removing voice channel associated roles.")
|
||||
if to_add:
|
||||
await member.add_roles(*to_add, reason="Adding voice channel associated roles.")
|
||||
|
||||
logger.info(
|
||||
f"Voice roles removed {len(to_remove)} roles "
|
||||
f"and added {len(to_add)} roles to <uid: {member.id}>"
|
||||
)
|
||||
|
||||
async def get_roles_for(self, channelid: int) -> list[discord.Role]:
|
||||
"""
|
||||
Get the voice roles associated to the given channel, as a list.
|
||||
|
||||
Returns an empty list if there are no associated voice roles.
|
||||
"""
|
||||
rows = await self.data.VoiceRole.fetch_where(channelid=channelid)
|
||||
channel = self.bot.get_channel(channelid)
|
||||
if not channel:
|
||||
raise ValueError("Provided voice role target channel is not in cache.")
|
||||
|
||||
target_roles = []
|
||||
for row in rows:
|
||||
role = channel.guild.get_role(row.roleid)
|
||||
if role is not None:
|
||||
target_roles.append(role)
|
||||
|
||||
return target_roles
|
||||
|
||||
def event_lock(self, key) -> asyncio.Lock:
|
||||
"""
|
||||
Get an asyncio.Lock for the given key.
|
||||
|
||||
Guarantees sequential event handling.
|
||||
"""
|
||||
lock = self._event_locks.get(key, None)
|
||||
if lock is None:
|
||||
lock = self._event_locks[key] = asyncio.Lock()
|
||||
logger.debug(f"Getting video event lock {key} (locked: {lock.locked()})")
|
||||
return lock
|
||||
|
||||
|
||||
# -------- Commands --------
|
||||
@cmds.hybrid_group(
|
||||
name='voiceroles',
|
||||
description="Base command group for voice channel -> role associationes."
|
||||
)
|
||||
@appcmds.default_permissions(manage_channels=True)
|
||||
async def voicerole_group(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@voicerole_group.command(
|
||||
name="link",
|
||||
description="Link a given voice channel with a given role."
|
||||
)
|
||||
@appcmds.describe(
|
||||
channel="The voice channel to link.",
|
||||
role="The associated role to give to members joining the voice channel."
|
||||
)
|
||||
async def voicerole_link(self, ctx: LionContext,
|
||||
channel: discord.VoiceChannel,
|
||||
role: discord.Role):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
if not channel.permissions_for(ctx.author).manage_channels:
|
||||
await ctx.error_reply(f"You don't have the manage channels permission in {channel.mention}")
|
||||
return
|
||||
if not ctx.author.guild_permissions.manage_roles or not (role < ctx.author.top_role):
|
||||
await ctx.error_reply(f"You don't have the permission to manage this role!")
|
||||
return
|
||||
|
||||
await self.data.VoiceRole.table.insert(channelid=channel.id, roleid=role.id)
|
||||
await ctx.reply("Voice role associated!")
|
||||
|
||||
@voicerole_group.command(
|
||||
name="unlink",
|
||||
description="Unlink a given voice channel from a given role."
|
||||
)
|
||||
@appcmds.describe(
|
||||
channel="The voice channel to unlink.",
|
||||
role="The role to remove from this voice channel."
|
||||
)
|
||||
async def voicerole_unlink(self, ctx: LionContext,
|
||||
channel: discord.VoiceChannel,
|
||||
role: discord.Role):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
if not channel.permissions_for(ctx.author).manage_channels:
|
||||
await ctx.error_reply(f"You don't have the manage channels permission in {channel.mention}")
|
||||
return
|
||||
if not ctx.author.guild_permissions.manage_roles or not (role < ctx.author.top_role):
|
||||
await ctx.error_reply(f"You don't have the permission to manage this role!")
|
||||
return
|
||||
|
||||
await self.data.VoiceRole.table.delete_where(channelid=channel.id, roleid=role.id)
|
||||
await ctx.reply("Voice role disassociated!")
|
||||
|
||||
# TODO: Display and visual editing of roles.
|
||||
|
||||
27
src/modules/voiceroles/data.py
Normal file
27
src/modules/voiceroles/data.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from data import Registry, RowModel
|
||||
from data.columns import Integer, Timestamp
|
||||
|
||||
|
||||
class VoiceRoleData(Registry):
|
||||
class VoiceRole(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE voice_roles(
|
||||
voice_role_id SERIAL PRIMARY KEY,
|
||||
channelid BIGINT NOT NULL,
|
||||
roleid BIGINT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX voice_role_channels on voice_roles (channelid);
|
||||
"""
|
||||
# TODO: Worth associating a guildid to this as well? Denormalises though
|
||||
# Makes more theoretical sense to associated configurable channels to the guilds in a join table.
|
||||
_tablename_ = 'voice_roles'
|
||||
_cache_ = {}
|
||||
|
||||
voice_role_id = Integer(primary=True)
|
||||
channelid = Integer()
|
||||
roleid = Integer()
|
||||
|
||||
created_at = Timestamp()
|
||||
@@ -5,7 +5,7 @@ import datetime as dt
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
from discord import AllowedMentions, app_commands as appcmds
|
||||
|
||||
from data import Condition
|
||||
from meta import LionBot, LionCog, LionContext
|
||||
@@ -654,7 +654,7 @@ class VoiceTrackerCog(LionCog):
|
||||
|
||||
# ----- Commands -----
|
||||
@cmds.hybrid_command(
|
||||
name=_p('cmd:now', "now"),
|
||||
name="tag",
|
||||
description=_p(
|
||||
'cmd:now|desc',
|
||||
"Describe what you are working on, or see what your friends are working on!"
|
||||
@@ -668,7 +668,7 @@ class VoiceTrackerCog(LionCog):
|
||||
@appcmds.describe(
|
||||
tag=_p(
|
||||
'cmd:now|param:tag|desc',
|
||||
"Describe what you are working on in 10 characters or less!"
|
||||
"Describe what you are working!"
|
||||
),
|
||||
user=_p(
|
||||
'cmd:now|param:user|desc',
|
||||
@@ -681,17 +681,15 @@ class VoiceTrackerCog(LionCog):
|
||||
)
|
||||
@appcmds.guild_only
|
||||
async def now_cmd(self, ctx: LionContext,
|
||||
tag: Optional[appcmds.Range[str, 0, 10]] = None,
|
||||
tag: Optional[str] = None,
|
||||
*,
|
||||
user: Optional[discord.Member] = None,
|
||||
clear: Optional[bool] = None
|
||||
):
|
||||
if not ctx.guild:
|
||||
return
|
||||
if not ctx.interaction:
|
||||
return
|
||||
t = self.bot.translator.t
|
||||
|
||||
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
||||
is_moderator = await moderator_ctxward(ctx)
|
||||
target = user if user is not None else ctx.author
|
||||
session = self.get_session(ctx.guild.id, target.id, create=False)
|
||||
@@ -715,7 +713,7 @@ class VoiceTrackerCog(LionCog):
|
||||
"{mention} has no running session!"
|
||||
)).format(mention=target.mention)
|
||||
)
|
||||
await ctx.interaction.edit_original_response(embed=error)
|
||||
await ctx.reply(embed=error)
|
||||
return
|
||||
|
||||
if clear:
|
||||
@@ -723,87 +721,27 @@ class VoiceTrackerCog(LionCog):
|
||||
if target == ctx.author:
|
||||
# Clear the author's tag
|
||||
await session.set_tag(None)
|
||||
ack = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title=t(_p(
|
||||
'cmd:now|target:self|mode:clear|success|title',
|
||||
"Session Tag Cleared"
|
||||
)),
|
||||
description=t(_p(
|
||||
'cmd:now|target:self|mode:clear|success|desc',
|
||||
"Successfully unset your session tag."
|
||||
))
|
||||
)
|
||||
ack = "Cleared your current task!"
|
||||
elif not is_moderator:
|
||||
# Trying to clear someone else's tag without being a moderator
|
||||
ack = discord.Embed(
|
||||
colour=discord.Colour.brand_red(),
|
||||
title=t(_p(
|
||||
'cmd:now|target:other|mode:clear|error:perms|title',
|
||||
"You can't do that!"
|
||||
)),
|
||||
description=t(_p(
|
||||
'cmd:now|target:other|mode:clear|error:perms|desc',
|
||||
"You need to be a moderator to set or clear someone else's session tag."
|
||||
))
|
||||
)
|
||||
ack = "You need to be a moderator to set or clear someone else's task!"
|
||||
else:
|
||||
# Clearing someone else's tag as a moderator
|
||||
await session.set_tag(None)
|
||||
ack = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title=t(_p(
|
||||
'cmd:now|target:other|mode:clear|success|title',
|
||||
"Session Tag Cleared!"
|
||||
)),
|
||||
description=t(_p(
|
||||
'cmd:now|target:other|mode:clear|success|desc',
|
||||
"Cleared {target}'s session tag."
|
||||
)).format(target=target.mention)
|
||||
)
|
||||
ack = f"Cleared {target}'s current task!"
|
||||
elif tag:
|
||||
# Tag setting mode
|
||||
if target == ctx.author:
|
||||
# Set the author's tag
|
||||
await session.set_tag(tag)
|
||||
ack = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title=t(_p(
|
||||
'cmd:now|target:self|mode:set|success|title',
|
||||
"Session Tag Set!"
|
||||
)),
|
||||
description=t(_p(
|
||||
'cmd:now|target:self|mode:set|success|desc',
|
||||
"You are now working on `{new_tag}`. Good luck!"
|
||||
)).format(new_tag=tag)
|
||||
)
|
||||
ack = f"Set your current task to `{tag}`, good luck! <:goodluck:1266447460146876497>"
|
||||
elif not is_moderator:
|
||||
# Trying the set someone else's tag without being a moderator
|
||||
ack = discord.Embed(
|
||||
colour=discord.Colour.brand_red(),
|
||||
title=t(_p(
|
||||
'cmd:now|target:other|mode:set|error:perms|title',
|
||||
"You can't do that!"
|
||||
)),
|
||||
description=t(_p(
|
||||
'cmd:now|target:other|mode:set|error:perms|desc',
|
||||
"You need to be a moderator to set or clear someone else's session tag!"
|
||||
))
|
||||
)
|
||||
ack = "You need to be a moderator to set or clear someone else's task!"
|
||||
else:
|
||||
# Setting someone else's tag as a moderator
|
||||
await session.set_tag(tag)
|
||||
ack = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title=t(_p(
|
||||
'cmd:now|target:other|mode:set|success|title',
|
||||
"Session Tag Set!"
|
||||
)),
|
||||
description=t(_p(
|
||||
'cmd:now|target:other|mode:set|success|desc',
|
||||
"Set {target}'s session tag to `{new_tag}`."
|
||||
)).format(target=target.mention, new_tag=tag)
|
||||
)
|
||||
ack = f"Set {target}'s current task to `{tag}`"
|
||||
else:
|
||||
# Display tag and voice time
|
||||
if target == ctx.author:
|
||||
@@ -815,14 +753,14 @@ class VoiceTrackerCog(LionCog):
|
||||
else:
|
||||
desc = t(_p(
|
||||
'cmd:now|target:self|mode:show_without_tag|desc',
|
||||
"You have been working in {channel} since {time}!\n\n"
|
||||
"You have been working in {channel} since {time}! "
|
||||
"Use `/now <tag>` to set what you are working on."
|
||||
))
|
||||
else:
|
||||
if session.tag:
|
||||
desc = t(_p(
|
||||
'cmd:now|target:other|mode:show_with_tag|desc',
|
||||
"{target} is current working in {channel}!\n"
|
||||
"{target} is current working in {channel}! "
|
||||
"They have been working on **{tag}** since {time}."
|
||||
))
|
||||
else:
|
||||
@@ -830,18 +768,13 @@ class VoiceTrackerCog(LionCog):
|
||||
'cmd:now|target:other|mode:show_without_tag|desc',
|
||||
"{target} has been working in {channel} since {time}!"
|
||||
))
|
||||
desc = desc.format(
|
||||
ack = desc.format(
|
||||
tag=session.tag,
|
||||
channel=f"<#{session.state.channelid}>",
|
||||
time=discord.utils.format_dt(session.start_time, 't'),
|
||||
time=discord.utils.format_dt(session.start_time, 'R'),
|
||||
target=target.mention,
|
||||
)
|
||||
ack = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
description=desc,
|
||||
timestamp=utc_now()
|
||||
)
|
||||
await ctx.interaction.edit_original_response(embed=ack)
|
||||
await ctx.reply(ack, allowed_mentions=AllowedMentions.none())
|
||||
|
||||
# ----- Configuration Commands -----
|
||||
@LionCog.placeholder_group
|
||||
|
||||
@@ -29,13 +29,26 @@ class TwitchAuthCog(LionCog):
|
||||
self.bot = bot
|
||||
self.data = bot.db.load_registry(TwitchAuthData())
|
||||
|
||||
self.client_cache = {}
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
# ----- Auth API -----
|
||||
|
||||
async def fetch_client_for(self, userid: int):
|
||||
...
|
||||
async def fetch_client_for(self, userid: str):
|
||||
authrow = await self.data.UserAuthRow.fetch(userid)
|
||||
if authrow is None:
|
||||
# TODO: Some user authentication error
|
||||
self.client_cache.pop(userid, None)
|
||||
raise ValueError("Requested user is not authenticated.")
|
||||
if (twitch := self.client_cache.get(userid)) is None:
|
||||
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
|
||||
scopes = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||
authscopes = [AuthScope(scope) for scope in scopes]
|
||||
await twitch.set_user_authentication(authrow.access_token, authscopes, authrow.refresh_token)
|
||||
self.client_cache[userid] = twitch
|
||||
return twitch
|
||||
|
||||
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
|
||||
"""
|
||||
@@ -46,7 +59,9 @@ class TwitchAuthCog(LionCog):
|
||||
if authrow:
|
||||
if scopes:
|
||||
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||
has_auth = set(map(str, scopes)).issubset(has_scopes)
|
||||
desired = {scope.value for scope in scopes}
|
||||
has_auth = desired.issubset(has_scopes)
|
||||
logger.info(f"Auth check for `{userid}`: Requested scopes {desired}, has scopes {has_scopes}. Passed: {has_auth}")
|
||||
else:
|
||||
has_auth = True
|
||||
else:
|
||||
@@ -58,6 +73,7 @@ class TwitchAuthCog(LionCog):
|
||||
Start the user authentication flow for the given userid.
|
||||
Will request the given scopes along with the default ones and any existing scopes.
|
||||
"""
|
||||
self.client_cache.pop(userid, None)
|
||||
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
|
||||
existing = map(AuthScope, existing_strs)
|
||||
to_request = set(existing).union(scopes)
|
||||
@@ -82,3 +98,17 @@ class TwitchAuthCog(LionCog):
|
||||
await ctx.reply(flow.auth.return_auth_url())
|
||||
await flow.run()
|
||||
await ctx.reply("Authentication Complete!")
|
||||
|
||||
@cmds.hybrid_command(name='modauth')
|
||||
async def cmd_modauth(self, ctx: LionContext):
|
||||
if ctx.interaction:
|
||||
await ctx.interaction.response.defer(ephemeral=True)
|
||||
scopes = [
|
||||
AuthScope.MODERATOR_READ_FOLLOWERS,
|
||||
AuthScope.CHANNEL_READ_REDEMPTIONS,
|
||||
AuthScope.MODERATOR_MANAGE_CHAT_MESSAGES,
|
||||
]
|
||||
flow = await self.start_auth(scopes=scopes)
|
||||
await ctx.reply(flow.auth.return_auth_url())
|
||||
await flow.run()
|
||||
await ctx.reply("Authentication Complete!")
|
||||
|
||||
@@ -64,7 +64,7 @@ class TwitchAuthData(Registry):
|
||||
"""
|
||||
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
|
||||
|
||||
return [row.scope for row in rows] if rows else []
|
||||
return [row['scope'] for row in rows] if rows else []
|
||||
|
||||
|
||||
"""
|
||||
@@ -76,4 +76,4 @@ class TwitchAuthData(Registry):
|
||||
);
|
||||
CREATE INDEX twitch_user_scopes_userid ON twitch_user_scopes (userid);
|
||||
"""
|
||||
user_scopes = Table('twitch_token_scopes')
|
||||
user_scopes = Table('twitch_user_scopes')
|
||||
|
||||
@@ -47,7 +47,7 @@ class UserAuthFlow:
|
||||
self._setup_done.set()
|
||||
return await ws.receive_json()
|
||||
|
||||
async def run(self):
|
||||
async def run(self) -> TwitchAuthData.UserAuthRow:
|
||||
if not self._setup_done.is_set():
|
||||
raise ValueError("Cannot run UserAuthFlow before setup.")
|
||||
if self._comm_task is None:
|
||||
@@ -56,7 +56,7 @@ class UserAuthFlow:
|
||||
result = await self._comm_task
|
||||
if result.get('error', None):
|
||||
# TODO Custom auth errors
|
||||
# This is only documented to occure when the user denies the auth
|
||||
# This is only documented to occur when the user denies the auth
|
||||
raise SafeCancellation(f"Could not authenticate user! Reason: {result['error_description']}")
|
||||
|
||||
if result.get('state', None) != self.auth.state:
|
||||
|
||||
128
src/utils/lib.py
128
src/utils/lib.py
@@ -7,6 +7,7 @@ import iso8601 # type: ignore
|
||||
import pytz
|
||||
import re
|
||||
import json
|
||||
import asyncio
|
||||
from contextvars import Context
|
||||
|
||||
import discord
|
||||
@@ -341,9 +342,9 @@ def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -
|
||||
return "".join(reply_msg)
|
||||
|
||||
|
||||
def _parse_dur(time_str: str) -> int:
|
||||
def parse_dur(time_str: str) -> int:
|
||||
"""
|
||||
Parses a user provided time duration string into a timedelta object.
|
||||
Parses a user provided time duration string into an integer number of seconds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -918,3 +919,126 @@ def write_records(records: list[dict[str, Any]], stream: StringIO):
|
||||
for record in records:
|
||||
stream.write(','.join(map(str, record.values())))
|
||||
stream.write('\n')
|
||||
|
||||
|
||||
async def pager(ctx, pages, locked=True, start_at=0, add_cancel=False, **kwargs):
|
||||
"""
|
||||
Shows the user each page from the provided list `pages` one at a time,
|
||||
providing reactions to page back and forth between pages.
|
||||
This is done asynchronously, and returns after displaying the first page.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pages: List(Union(str, discord.Embed))
|
||||
A list of either strings or embeds to display as the pages.
|
||||
locked: bool
|
||||
Whether only the `ctx.author` should be able to use the paging reactions.
|
||||
kwargs: ...
|
||||
Remaining keyword arguments are transparently passed to the reply context method.
|
||||
|
||||
Returns: discord.Message
|
||||
This is the output message, returned for easy deletion.
|
||||
"""
|
||||
cancel_emoji = cross
|
||||
# Handle broken input
|
||||
if len(pages) == 0:
|
||||
raise ValueError("Pager cannot page with no pages!")
|
||||
|
||||
# Post first page. Method depends on whether the page is an embed or not.
|
||||
if isinstance(pages[start_at], discord.Embed):
|
||||
out_msg = await ctx.reply(embed=pages[start_at], **kwargs)
|
||||
else:
|
||||
out_msg = await ctx.reply(pages[start_at], **kwargs)
|
||||
|
||||
# Run the paging loop if required
|
||||
if len(pages) > 1:
|
||||
task = asyncio.create_task(_pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs))
|
||||
# ctx.tasks.append(task)
|
||||
elif add_cancel:
|
||||
await out_msg.add_reaction(cancel_emoji)
|
||||
|
||||
# Return the output message
|
||||
return out_msg
|
||||
|
||||
|
||||
async def _pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs):
|
||||
"""
|
||||
Asynchronous initialiser and loop for the `pager` utility above.
|
||||
"""
|
||||
# Page number
|
||||
page = start_at
|
||||
|
||||
# Add reactions to the output message
|
||||
next_emoji = "▶"
|
||||
prev_emoji = "◀"
|
||||
cancel_emoji = cross
|
||||
|
||||
try:
|
||||
await out_msg.add_reaction(prev_emoji)
|
||||
if add_cancel:
|
||||
await out_msg.add_reaction(cancel_emoji)
|
||||
await out_msg.add_reaction(next_emoji)
|
||||
except discord.Forbidden:
|
||||
# We don't have permission to add paging emojis
|
||||
# Die as gracefully as we can
|
||||
if ctx.guild:
|
||||
perms = ctx.channel.permissions_for(ctx.guild.me)
|
||||
if not perms.add_reactions:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results because I do not have the `add_reactions` permission!"
|
||||
)
|
||||
elif not perms.read_message_history:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results because I do not have the `read_message_history` permission!"
|
||||
)
|
||||
else:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results due to insufficient permissions!"
|
||||
)
|
||||
else:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results!"
|
||||
)
|
||||
return
|
||||
|
||||
# Check function to determine whether a reaction is valid
|
||||
def check(reaction, user):
|
||||
result = reaction.message.id == out_msg.id
|
||||
result = result and str(reaction.emoji) in [next_emoji, prev_emoji]
|
||||
result = result and not (user.id == ctx.bot.user.id)
|
||||
result = result and not (locked and user != ctx.author)
|
||||
return result
|
||||
|
||||
# Begin loop
|
||||
while True:
|
||||
# Wait for a valid reaction, break if we time out
|
||||
try:
|
||||
reaction, user = await ctx.bot.wait_for('reaction_add', check=check, timeout=300)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
# Attempt to remove the user's reaction, silently ignore errors
|
||||
asyncio.ensure_future(out_msg.remove_reaction(reaction.emoji, user))
|
||||
|
||||
# Change the page number
|
||||
page += 1 if reaction.emoji == next_emoji else -1
|
||||
page %= len(pages)
|
||||
|
||||
# Edit the message with the new page
|
||||
active_page = pages[page]
|
||||
if isinstance(active_page, discord.Embed):
|
||||
await out_msg.edit(embed=active_page, **kwargs)
|
||||
else:
|
||||
await out_msg.edit(content=active_page, **kwargs)
|
||||
|
||||
# Clean up by removing the reactions
|
||||
try:
|
||||
await out_msg.clear_reactions()
|
||||
except discord.Forbidden:
|
||||
try:
|
||||
await out_msg.remove_reaction(next_emoji, ctx.client.user)
|
||||
await out_msg.remove_reaction(prev_emoji, ctx.client.user)
|
||||
except discord.NotFound:
|
||||
pass
|
||||
except discord.NotFound:
|
||||
pass
|
||||
|
||||
7
tests/__init__.py
Normal file
7
tests/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# !/bin/python3
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||
0
tests/gui/__init__.py
Normal file
0
tests/gui/__init__.py
Normal file
@@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
from src.cards import WeeklyGoalCard
|
||||
from gui.cards import WeeklyGoalCard
|
||||
|
||||
|
||||
async def get_card():
|
||||
card = await WeeklyGoalCard.generate_sample()
|
||||
with open('samples/weekly-sample.png', 'wb') as image_file:
|
||||
with open('output/weekly-sample.png', 'wb') as image_file:
|
||||
image_file.write(card.fp.read())
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
15
tests/gui/cards/pomo_sample.py
Normal file
15
tests/gui/cards/pomo_sample.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
from gui.cards import BreakTimerCard, FocusTimerCard
|
||||
|
||||
|
||||
async def get_card():
|
||||
card = await BreakTimerCard.generate_sample()
|
||||
with open('output/break_timer_sample.png', 'wb') as image_file:
|
||||
image_file.write(card.fp.read())
|
||||
card = await FocusTimerCard.generate_sample()
|
||||
with open('output/focus_timer_sample.png', 'wb') as image_file:
|
||||
image_file.write(card.fp.read())
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(get_card())
|
||||
Reference in New Issue
Block a user