5 Commits

Author SHA1 Message Date
7cbb6adcb8 fix(reminders): Await cog load. 2025-05-23 18:23:23 +10:00
010d52e72e fix(reminders): Load twitch methods. 2025-05-23 18:21:21 +10:00
c5e9cb1488 feat (reminders): Add simple twitch reminders. 2025-05-23 18:17:04 +10:00
d1114f1a06 (focus): Basic hyperfocus implementation. 2025-05-22 20:26:17 +10:00
2d87783c3e (twitch): Add user auth caching.
Fix issues with check_auth.
Implement fetch_client_for.
Add 'modauth' app command for basic mod scopes.
2025-05-22 20:23:34 +10:00
8 changed files with 693 additions and 6 deletions

View File

@@ -32,6 +32,8 @@ active_discord = [
'.shoutouts',
'.tagstrings',
'.voiceroles',
'.hyperfocus',
'.twreminders',
]
async def setup(bot):

View File

@@ -0,0 +1,8 @@
import logging
logger = logging.getLogger(__name__)
from .cog import HyperFocusCog
async def setup(bot):
await bot.add_cog(HyperFocusCog(bot))

View File

@@ -0,0 +1,306 @@
import asyncio
import json
from typing import Optional
from dataclasses import dataclass
import twitchio
from twitchio.ext import commands
from twitchAPI.type import AuthScope
import random
import datetime as dt
from datetime import timedelta, datetime
from meta import CrocBot, LionCog, LionContext, LionBot
from meta.sockets import Channel, register_channel
from utils.lib import strfdelta, utc_now
from . import logger
@dataclass
class FocusState:
userid: int | str
name: str
focus_ends: datetime
hyper: bool = True
class FocusChannel(Channel):
name = 'FocusList'
def __init__(self, cog: 'HyperFocusCog', **kwargs):
self.cog = cog
super().__init__(**kwargs)
async def on_connection(self, websocket, event):
await super().on_connection(websocket, event)
await self.reload_focus(websocket=websocket)
def focus_args(self, state: FocusState):
return (
state.userid,
state.name,
state.hyper,
state.focus_ends.isoformat(),
)
async def reload_focus(self, websocket=None):
"""
Clear tasklist and re-send current tasks.
"""
await self.send_clear(websocket=websocket)
for state in self.cog.hyperfocusing.values():
await self.send_set(*self.focus_args(state), websocket=websocket)
async def send_set(self, userid, name, hyper, end_at, websocket=None):
await self.send_event({
'type': "DO",
'method': "setFocus",
'args': {
'userid': userid,
'name': name,
'hyper': hyper,
'end_at': end_at,
}
}, websocket=websocket)
async def send_del(self, userid, websocket=None):
await self.send_event({
'type': "DO",
'method': "delFocus",
'args': {
'userid': userid,
}
}, websocket=websocket)
async def send_clear(self, websocket=None):
await self.send_event({
'type': "DO",
'method': "clearFocus",
'args': {
}
}, websocket=websocket)
class HyperFocusCog(LionCog):
def __init__(self, bot: CrocBot):
self.bot = bot
self.crocbot: CrocBot = bot.crocbot
# userid -> timestamp when they stop
self.hyperfocusing: dict[str, FocusState] = {}
self.channel = FocusChannel(self)
register_channel(self.channel.name, self.channel)
self.loaded = asyncio.Event()
async def cog_load(self):
self._load_twitch_methods(self.crocbot)
self.load_hyperfocus()
self.loaded.set()
async def cog_unload(self):
self._unload_twitch_methods(self.crocbot)
def save_hyperfocus(self):
with open('hyperfocus.json', 'w', encoding='utf-8') as f:
mapped = {
userid: {
'userid': str(state.userid),
'name': state.name,
'focus_ends': state.focus_ends.isoformat(),
'hyper': state.hyper
}
for userid, state in self.hyperfocusing.items()
}
json.dump(mapped, f, ensure_ascii=False, indent=4)
def load_hyperfocus(self):
with open('hyperfocus.json') as f:
mapped = json.load(f)
self.hyperfocusing.clear()
for userid, map in mapped.items():
self.hyperfocusing[str(userid)] = FocusState(
userid=str(map['userid']),
name=map['name'],
hyper=map['hyper'],
focus_ends=dt.datetime.fromisoformat(map['focus_ends'])
)
print(f"Loaded hyperfocus: {self.hyperfocusing}")
def check_hyperfocus(self, userid):
"""
Returns whether a user is currently in HYPERFOCUS mode!
"""
return (state := self.hyperfocusing.get(userid, None)) and utc_now() < state.focus_ends
@commands.Cog.event('event_message')
async def on_message(self, message: twitchio.Message):
if message.content and message.content.lower() == 'nice':
await message.channel.send("That's Nice")
await self.good_croccy_handler(message)
tags = message.tags
if tags and message.content and self.check_hyperfocus(tags.get('user-id')):
if not self.valid_focus_message(message):
logger.info(
f"Deleting message from hyperfocused user. {message.raw_data=}"
)
await asyncio.sleep(1)
msgid = tags['id']
# TODO: Better selection for moderator
# i.e. if the message is not from the broadcaster and we do have delete perms
# then use our own token.
broadcasterid = tags['room-id']
authcog = self.bot.get_cog('TwitchAuthCog')
if not await authcog.check_auth(broadcasterid, scopes=[AuthScope.MODERATOR_MANAGE_CHAT_MESSAGES]):
await message.channel.send(f"@{message.author.name} Stay focused! (I tried to delete your message because you are in !hyperfocus. Unfortunately I don't have the permissions to do that. But stay focused anyway!)")
else:
twitch = await authcog.fetch_client_for(broadcasterid)
await twitch.delete_chat_message(
broadcasterid,
broadcasterid,
msgid,
)
await message.channel.send(
f"@{message.author.name} Stay focused! (I deleted your message because you are in !hyperfocus, use !unfocus to come back.)"
)
async def good_croccy_handler(self, message: twitchio.Message):
if not message.content:
return
cleaned = message.content.lower().replace('@croccyhelper', '').strip()
if cleaned in ('good croc', 'good croccy', 'good helper'):
await message.channel.send("holono1Heart")
elif cleaned in ('bad croc', 'bad croccy', 'bad helper'):
await message.channel.send("holono1Sad")
async def chemical_handler(self, message: twitchio.Message):
if not message.content:
return
cleaned = message.content.lower().strip()
if cleaned in ('oh',):
await message.channel.send('Oxygen Hydrogen!')
def valid_focus_message(self, message: twitchio.Message) -> bool:
"""
Determined whether the given message is allowed to be sent in !hyperfocus.
That is, if it appears to be emote-only or a command.
"""
content = message.content
if not content:
return True
tags = message.tags or {}
to_remove = []
if (replying := tags.get('reply-parent-user-login', '')) and content.startswith('@'):
# Trim the mention from the start of the content
splits = content.split(maxsplit=1)
to_remove.append((0, len(splits[0])))
if emotesstr := tags.get('emotes', ''):
for emotestr in emotesstr.split('/'):
emote, locs = emotestr.split(':')
for loc in locs.split(','):
start, end = loc.split('-')
to_remove.append((int(start), int(end) + 1))
# Sort the pairs to remove by descending starting index
# This should allow clean removal with a loop as long as there are no intersections.
to_remove.sort(key=lambda pair: pair[0], reverse=True)
for start, end in to_remove:
content = content[:start] + content[end:]
content = content.strip().replace(' ', '').replace('\n', '')
allowed = not content or content.startswith('!') or content.startswith('*')
allowed = allowed or all(not char.isascii() for char in content)
if not allowed:
logger.info(f"Invalid hyperfocus message. Trimmed content: {content}")
return allowed
@commands.command(name='coinflip')
async def coinflip(self, ctx):
await ctx.reply(random.choice(('heads', 'tails')))
@commands.command(name='choose')
async def choose(self, ctx, *, args: str):
if not args:
await ctx.reply("Give me something to choose, e.g. !choose Heads | Tails")
else:
options = args.split('|')
options = [option.strip() for option in options]
options = [option for option in options if option]
choice = random.choice(options)
if random.random() < 0.01:
choice = "You"
await ctx.reply(f"I choose: {choice}")
@commands.command(name='hyperfocus')
async def hyperfocus_cmd(self, ctx, dur: Optional[int] = None):
userid = str(ctx.author.id)
now = utc_now()
end_time = None
if dur is None:
# Automatically select time
next_hour = now.replace(minute=0, second=0, microsecond=0) + dt.timedelta(hours=1)
next_block = next_hour - dt.timedelta(minutes=10)
if now > next_block:
# Currently in the break
next_block = next_block + dt.timedelta(hours=1)
end_time = next_block
dur = int((end_time - now).total_seconds() // 60)
elif dur > 720:
await ctx.reply("You can hyperfocus for at most 12 hours at a time!")
else:
end_time = utc_now() + dt.timedelta(minutes=dur)
if end_time is not None:
state = self.hyperfocusing[userid] = FocusState(
userid=userid,
name=ctx.author.display_name,
focus_ends=end_time,
)
self.save_hyperfocus()
await self.channel.send_set(*self.channel.focus_args(state))
await ctx.reply(
f"{ctx.author.name} has gone into HYPERFOCUS mode! "
f"They will be in emote and command only mode for the next {dur} minutes! "
"Use !unfocus if you really need to chat before then, best of luck! 🍀"
)
@commands.command(name='unfocus')
async def unfocus_cmd(self, ctx):
self.hyperfocusing.pop(ctx.author.id, None)
self.save_hyperfocus()
await self.channel.send_del(ctx.author.id)
await ctx.reply("Welcome back from focus, hope it went well! Have a comfy break and remember to have a sippie and a stretch~")
@commands.command(name='hyperfocused')
async def focused_cmd(self, ctx, user: Optional[twitchio.User] = None):
user = user if user is not None else ctx.author
userid = str(user.id)
if self.check_hyperfocus(userid):
state = self.hyperfocusing.get(userid)
end_time = state.focus_ends
durstr = strfdelta(end_time - utc_now())
await ctx.reply(
f"{user.name} is in HYPERFOCUS for another {durstr}! "
"They can only write emojis and commands in this time~ "
"(use !unfocus to come back if you need to!) "
"Good luck!"
)
elif userid != str(ctx.author.id):
await ctx.reply(
f"{user.name} is not hyperfocused!"
)
else:
await ctx.reply(
"You are not hyperfocused! "
"Enter HYPERFOCUS mode for e.g. 10 minutes by writing !hyperfocus 10"
)

View File

@@ -0,0 +1,8 @@
import logging
logger = logging.getLogger(__name__)
from .cog import ReminderCog
async def setup(bot):
await bot.add_cog(ReminderCog(bot))

View File

@@ -0,0 +1,333 @@
import asyncio
import json
import re
import itertools
from typing import Optional
from dataclasses import dataclass
from collections import defaultdict
import twitchio
from twitchio.ext import commands
import datetime as dt
from datetime import timedelta, datetime
from meta import CrocBot, LionCog, LionContext, LionBot
from utils.lib import strfdelta, utc_now, parse_dur
from . import logger
reminder_regex = re.compile(
r"""
(^)?(?P<type> (?: \b in) | (?: every))
\s*(?P<duration> (?: day| hour| (?:\d+\s*(?:(?:d|h|m|s)[a-zA-Z]*)?(?:\s|and)*)+))
(?:(?(1) (?:, | ; | : | \. | to)? | $))
""",
re.IGNORECASE | re.VERBOSE | re.DOTALL
)
@dataclass
class Reminder:
userid: int
content: str
name: str
channel: str
remind_at: datetime
class ReminderCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.crocbot: CrocBot = bot.crocbot
self.loaded = asyncio.Event()
self.reminders: dict[int, list[Reminder]] = defaultdict(list)
self.next_reminder_task = None
self._reminder_wait_task = None
self.reminder_lock = asyncio.Lock()
async def cog_load(self):
await self.load_reminders()
self._load_twitch_methods(self.crocbot)
self.loaded.set()
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
async def cog_unload(self):
self._unload_twitch_methods(self.crocbot)
async def cog_check(self, ctx):
await self.ensure_loaded()
return True
def save_reminders(self):
with open('reminders.json', 'w', encoding='utf-8') as f:
mapped = {
int(userid): [
{
'userid': int(state.userid),
'name': state.name,
'channel': state.channel,
'content': state.content,
'remind_at': state.remind_at.isoformat(),
}
for state in states
]
for userid, states in self.reminders.items()
}
json.dump(mapped, f, ensure_ascii=False, indent=4)
async def load_reminders(self):
if self.next_reminder_task and not self.next_reminder_task.cancelled():
self.next_reminder_task.cancel()
self.next_reminder_task = None
with open('reminders.json') as f:
mapped = json.load(f)
self.reminders.clear()
for userid, states in mapped.items():
userid = int(userid)
for map in states:
reminder = Reminder(
userid=int(map['userid']),
content=map['content'],
name=map['name'],
channel=map['channel'],
remind_at=dt.datetime.fromisoformat(map['remind_at'])
)
self.reminders[userid].append(reminder)
self.schedule_next_reminder()
logger.info(f"Loaded reminders: {self.reminders}")
def schedule_next_reminder(self):
"""
Schedule the next reminder in the queue, if it exists, and return it.
Cancels any currently running task.
"""
if not self.reminders:
return None
next_reminder = min(
itertools.chain(*self.reminders.values()), key=lambda r: r.remind_at, default=None
)
if next_reminder:
self.next_reminder_task = asyncio.create_task(self.run_reminder(next_reminder))
else:
# We still need to cancel any ongoing reminders
if self._reminder_wait_task and not self._reminder_wait_task.cancelled():
self._reminder_wait_task.cancel()
async def run_reminder(self, reminder: Reminder):
"""
Wait for and then run the given reminder.
Expects to be cancelled if another reminder is scheduled earlier.
"""
# Cancel the next reminder wait task.
# If the next reminder is currently executing/firing,
# this will do nothing and we will wait until it is finished.
if self._reminder_wait_task and not self._reminder_wait_task.cancelled():
self._reminder_wait_task.cancel()
# This ensures that only one reminder task runs at once
async with self.reminder_lock:
now = utc_now()
to_wait = (reminder.remind_at - now).total_seconds()
try:
self._reminder_wait_task = asyncio.create_task(asyncio.sleep(to_wait))
await self._reminder_wait_task
except asyncio.CancelledError:
# Reminder task was cancelled
raise
# Now fire the reminder
await self.fire_reminder(reminder)
# And schedule the next reminder if needed
self.schedule_next_reminder()
async def fire_reminder(self, reminder: Reminder):
"""
Actually run the given reminder.
"""
# Check that this reminder is still valid
if reminder not in self.reminders[reminder.userid]:
logger.error(f"Reminder {reminder!r} is firing but not scheduled!")
return
# We don't want to reschedule while a reminder is running
# Get the channel to send to
destination = self.crocbot.get_channel(reminder.channel)
if destination is None:
logger.info(f"Reminder couldn't get channel '{reminder.channel}'. Trying again in a minute.")
# In case we aren't actually ready yet
await self.crocbot.wait_for_ready()
try:
await asyncio.sleep(60)
except asyncio.CancelledError:
logger.info("Cancelling channel wait task for reminder.")
raise
destination = self.crocbot.get_channel(reminder.channel)
if destination is None:
# This means we haven't joined the channel
logger.warning(f"Reminder couldn't get channel '{reminder.channel}' for the second time. Cancelling.")
else:
logger.info(f"Channel '{reminder.channel}' found as {destination}. Continuing.")
if destination is not None:
# Send the reminder
msg = f"@{reminder.name}, you asked me to remind you: {reminder.content}"
await destination.send(msg)
# This should really be based on a reminderid but oh well
# It's theoretically possible for a reminder to be scheduled at the same time as it is run
# In which case the wrong reminder will be removed.
self.reminders[reminder.userid].remove(reminder)
self.save_reminders()
def get_reminders_for(self, userid: int):
return self.reminders.get(userid, [])
@commands.command(name='remindme', aliases=['reminders', 'reminder'])
async def remindme_cmd(self, ctx, *, args: str=''):
args = args.strip()
userid = int(ctx.author.id)
existing = self.get_reminders_for(userid)
existing.sort(key=lambda r: r.remind_at, reverse=False)
now = utc_now()
if not args or args.lower() in ('show', 'list'):
# Show user's current reminders or show usage
if not existing:
await ctx.reply(
"USAGE: !remindme <task> in <dur> EG: !remindme Coffee is ready in 10m | !remindme in 10m, Coffee is ready"
)
elif len(existing) == 1:
reminder = existing[0]
dur = reminder.remind_at - now
sec = (dur.total_seconds()) < 60
formatted_dur = strfdelta(dur, short=False, sec=sec)
await ctx.reply(
f"I will remind you about '{reminder.content}' in about {formatted_dur}. Use !remindme cancel to cancel!"
)
else:
parts = []
for i, reminder in enumerate(existing, start=1):
dur = reminder.remind_at - now
sec = (dur.total_seconds()) < 60
formatted_dur = strfdelta(dur, short=True, sec=sec)
parts.append(
f"{i}: '{reminder.content}' in {formatted_dur}"
)
remstr = '; '.join(parts)
if len(remstr) > 290:
remstr = remstr[:290] + '...'
await ctx.reply(
f"Active Reminders: {remstr}. Use '!remindme cancel n' or '!remindme clear' to remove!"
)
elif args.lower() in ('clear', 'clearall', 'remove all'):
# Remove all reminders
if existing:
self.reminders.pop(userid, None)
self.save_reminders()
self.schedule_next_reminder()
else:
await ctx.reply("You don't have any reminders set!")
elif args.lower().split(maxsplit=1)[0] in ('remove', 'cancel'):
splits = args.split(maxsplit=1)
remaining = splits[1].strip() if len(splits) > 1 else ''
# Remove a specified reminder
to_remove = None
if not existing:
await ctx.reply("You don't have any reminders set!")
elif len(existing) == 1:
to_remove = existing[0]
elif remaining.isdigit():
# Try to the remove the reminder with the give number
given = int(remaining)
if given > len(existing):
await ctx.reply(f"You only have {len(existing)} reminders!")
else:
to_remove = existing[given - 1]
else:
# Invalid arguments, show usage
await ctx.reply(
"USAGE: !remindme cancel <number>, e.g. !remindme cancel 1 to cancel your first reminder!"
)
if to_remove is not None:
self.reminders[userid].remove(to_remove)
await ctx.reply(
f"Cancelled your reminder '{to_remove.content}'"
)
self.save_reminders()
self.schedule_next_reminder()
else:
# Parse for reminder
content = None
duration = None
repeating = None
# First parse it
match = re.search(reminder_regex, args)
if match:
repeating = match.group('type').lower() == 'every'
duration_str = match.group('duration').lower()
if duration_str.isdigit():
# Default to minutes if no unit given
duration = int(duration_str) * 60
elif duration_str in ('day', 'a day'):
duration = 24 * 60 * 60
elif duration_str in ('hour', 'an hour'):
duration = 60 * 60
else:
duration = parse_dur(duration_str)
content = (args[:match.start()] + args[match.end():]).strip()
if content.startswith('to '):
content = content[3:].strip()
else:
# Legacy parsing, without requiring "in" at the front
splits = args.split(maxsplit=1)
if len(splits) == 2 and splits[0].isdigit():
repeating = False
duration = int(splits[0]) * 60
content = splits[1].strip()
# Sanity checking
if not duration or not content:
return await ctx.reply(
"Sorry, I didn't understand your reminder! Please use e.g. !remindme Coffee is ready in 10m"
)
if repeating:
return await ctx.reply(
"Sorry, we don't support repeating reminders right now!"
)
if len(existing) > 10:
return await ctx.reply(
"Sorry, you can only have 10 active reminders! Use !remindme cancel or !remindme clear to cancel some!"
)
reminder = Reminder(
userid=userid,
content=content,
name=ctx.author.name,
channel=ctx.channel.name,
remind_at=now + timedelta(seconds=duration)
)
self.reminders[userid].append(reminder)
dur = reminder.remind_at - now
sec = (dur.total_seconds()) < 60
formatted_dur = strfdelta(dur, short=False, sec=sec)
msg = f"Got it! I will remind you in {formatted_dur}!"
await ctx.reply(msg)
self.save_reminders()
self.schedule_next_reminder()

View File

@@ -29,13 +29,26 @@ class TwitchAuthCog(LionCog):
self.bot = bot
self.data = bot.db.load_registry(TwitchAuthData())
self.client_cache = {}
async def cog_load(self):
await self.data.init()
# ----- Auth API -----
async def fetch_client_for(self, userid: int):
...
async def fetch_client_for(self, userid: str):
authrow = await self.data.UserAuthRow.fetch(userid)
if authrow is None:
# TODO: Some user authentication error
self.client_cache.pop(userid, None)
raise ValueError("Requested user is not authenticated.")
if (twitch := self.client_cache.get(userid)) is None:
twitch = await Twitch(self.bot.config.twitch['app_id'], self.bot.config.twitch['app_secret'])
scopes = await self.data.UserAuthRow.get_scopes_for(userid)
authscopes = [AuthScope(scope) for scope in scopes]
await twitch.set_user_authentication(authrow.access_token, authscopes, authrow.refresh_token)
self.client_cache[userid] = twitch
return twitch
async def check_auth(self, userid: str, scopes: list[AuthScope] = []) -> bool:
"""
@@ -46,7 +59,9 @@ class TwitchAuthCog(LionCog):
if authrow:
if scopes:
has_scopes = await self.data.UserAuthRow.get_scopes_for(userid)
has_auth = set(map(str, scopes)).issubset(has_scopes)
desired = {scope.value for scope in scopes}
has_auth = desired.issubset(has_scopes)
logger.info(f"Auth check for `{userid}`: Requested scopes {desired}, has scopes {has_scopes}. Passed: {has_auth}")
else:
has_auth = True
else:
@@ -58,6 +73,7 @@ class TwitchAuthCog(LionCog):
Start the user authentication flow for the given userid.
Will request the given scopes along with the default ones and any existing scopes.
"""
self.client_cache.pop(userid, None)
existing_strs = await self.data.UserAuthRow.get_scopes_for(userid)
existing = map(AuthScope, existing_strs)
to_request = set(existing).union(scopes)
@@ -82,3 +98,17 @@ class TwitchAuthCog(LionCog):
await ctx.reply(flow.auth.return_auth_url())
await flow.run()
await ctx.reply("Authentication Complete!")
@cmds.hybrid_command(name='modauth')
async def cmd_modauth(self, ctx: LionContext):
if ctx.interaction:
await ctx.interaction.response.defer(ephemeral=True)
scopes = [
AuthScope.MODERATOR_READ_FOLLOWERS,
AuthScope.CHANNEL_READ_REDEMPTIONS,
AuthScope.MODERATOR_MANAGE_CHAT_MESSAGES,
]
flow = await self.start_auth(scopes=scopes)
await ctx.reply(flow.auth.return_auth_url())
await flow.run()
await ctx.reply("Authentication Complete!")

View File

@@ -64,7 +64,7 @@ class TwitchAuthData(Registry):
"""
rows = await TwitchAuthData.user_scopes.select_where(userid=userid)
return [row.scope for row in rows] if rows else []
return [row['scope'] for row in rows] if rows else []
"""

View File

@@ -342,9 +342,9 @@ def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -
return "".join(reply_msg)
def _parse_dur(time_str: str) -> int:
def parse_dur(time_str: str) -> int:
"""
Parses a user provided time duration string into a timedelta object.
Parses a user provided time duration string into an integer number of seconds.
Parameters
----------