Merge branch 'staging' into feat-profiles

This commit is contained in:
2024-10-05 08:01:48 +10:00
21 changed files with 568 additions and 201 deletions

View File

@@ -30,6 +30,7 @@ active_discord = [
'.nowdoing',
'.shoutouts',
'.tagstrings',
'.voiceroles',
]
async def setup(bot):

View File

@@ -25,6 +25,75 @@ class PERIOD(Enum):
YEAR = ('this year', 'y', 'year', 'yearly')
def counter_cmd_factory(
counter: str,
response: str,
default_period: Optional[PERIOD] = PERIOD.STREAM,
context: Optional[str] = None
):
context = context or f"cmd: {counter}"
async def counter_cmd(cog, ctx: commands.Context, *, args: Optional[str] = None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await cog.parse_period(channelid, '', default=default_period)
args = (args or '').strip(" 󠀀 ")
splits = args.split(maxsplit=1)
splits = [split.strip() for split in splits if split]
details = None
amount = 1
if splits:
if splits[0].isdigit() or (splits[0].startswith('-') and splits[0][1:].isdigit()):
amount = int(splits[0])
splits = splits[1:]
if splits:
details = ' '.join(splits)
await cog.add_to_counter(
counter, userid, amount,
context=context,
details=details
)
lb = await cog.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(
response.format(
total=total,
period=period,
period_name=period.value[0],
detailsorname=details or counter,
user_total=user_total,
)
)
async def lb_cmd(cog, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await cog.formatted_lb(counter, args, int(user.id)))
async def undo_cmd(cog, ctx: commands.Context):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
_counter = await cog.fetch_counter(counter)
query = cog.data.CounterEntry.fetch_where(
counterid=_counter.counterid,
userid=userid,
)
query.order_by('created_at', direction=ORDER.DESC)
query.limit(1)
results = await query
if not results:
await ctx.reply("Nothing to delete!")
else:
row = results[0]
await row.delete()
await ctx.reply("Undo successful!")
return (counter_cmd, lb_cmd, undo_cmd)
class CounterCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
@@ -38,6 +107,7 @@ class CounterCog(LionCog):
async def cog_load(self):
self._load_twitch_methods(self.crocbot)
await self.load_counter_commands()
await self.data.init()
await self.load_counters()
@@ -46,6 +116,29 @@ class CounterCog(LionCog):
async def cog_unload(self):
self._unload_twitch_methods(self.crocbot)
async def load_counter_commands(self):
rows = await self.data.CounterCommand.fetch_where()
for row in rows:
counter = await self.data.Counter.fetch(row.counterid)
counter_cb, lb_cb, undo_cb = counter_cmd_factory(
counter.name,
row.response
)
cmds = []
main_cmd = commands.command(name=row.name)(counter_cb)
cmds.append(main_cmd)
if row.lbname:
lb_cmd = commands.command(name=row.lbname)(lb_cb)
cmds.append(lb_cmd)
if row.undoname:
undo_cmd = commands.command(name=row.undoname)(undo_cb)
cmds.append(undo_cmd)
for cmd in cmds:
self.add_twitch_command(self.crocbot, cmd)
logger.info(f"(Re)Loaded {len(rows)} counter commands!")
async def cog_check(self, ctx):
return True
@@ -80,13 +173,19 @@ class CounterCog(LionCog):
if row:
await self.data.CounterEntry.table.delete_where(counterid=row.counterid)
async def add_to_counter(self, counter: str, userid: int, value: int, context: Optional[str]=None):
async def add_to_counter(
self,
counter: str, userid: int, value: int,
context: Optional[str]=None,
details: Optional[str]=None,
):
row = await self.fetch_counter(counter)
return await self.data.CounterEntry.create(
counterid=row.counterid,
userid=userid,
value=value,
context_str=context
context_str=context,
details=details
)
async def leaderboard(self, counter: str, start_time=None):
@@ -155,8 +254,43 @@ class CounterCog(LionCog):
elif subcmd == 'clear':
await self.reset_counter(name)
await ctx.reply(f"'{name}' counter reset.")
elif subcmd == 'alias':
splits = args.split(maxsplit=3) if args else []
counter = await self.fetch_counter(name)
rows = await self.data.CounterCommand.fetch_where(counterid=counter.counterid)
existing = rows[0] if rows else None
if existing and not args:
# Show current alias
await ctx.reply(
f"Counter '{name}' aliases: '!{existing.name}' to add to counter; "
f"'!{existing.lbname}' to view counter leaderboard; "
f"'!{existing.undoname}' to undo (your) last addition."
)
elif len(splits) < 4:
# Show usage
await ctx.reply(
"USAGE: !counter <name> alias <cmdname> <lbname> <undoname> <response> -- "
"Response accepts keywords {total}, {period}, {period_name}, {detailsorname}, {user_total}."
)
else:
# Create new alias
cmdname, lbname, undoname, response = splits
# Remove any existing alias
await self.data.CounterCommand.table.delete_where(name=cmdname)
alias = await self.data.CounterCommand.create(
name=cmdname,
counterid=counter.counterid,
lbname=lbname, undoname=undoname, response=response
)
await self.load_counter_commands()
await ctx.reply(
f"Alias created for counter '{name}': '!{alias.name}' to add to counter; "
f"'!{alias.lbname}' to view counter leaderboard; "
f"'!{alias.undoname}' to undo (your) last addition."
)
else:
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear'.")
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear', 'alias'.")
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
if periodstr:
@@ -211,82 +345,3 @@ class CounterCog(LionCog):
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}"
else:
return f"{counter} {period.value[-1]} leaderboard is empty!"
# Misc actual counter commands
# TODO: Factor this out to a different module...
@commands.command()
async def tea(self, ctx: commands.Context, *, args: Optional[str]=None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await self.parse_period(channelid, '')
counter = 'tea'
await self.add_to_counter(
counter,
userid,
1,
context='cmd: tea'
)
lb = await self.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(f"Enjoy your tea! We have had {total} cups of tea {period.value[0]}.")
@commands.command()
async def tealb(self, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await self.formatted_lb('tea', args, int(user.id)))
@commands.command()
async def coffee(self, ctx: commands.Context, *, args: Optional[str]=None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await self.parse_period(channelid, '')
counter = 'coffee'
await self.add_to_counter(
counter,
userid,
1,
context='cmd: coffee'
)
lb = await self.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(f"Enjoy your coffee! We have had {total} cups of coffee {period.value[0]}.")
@commands.command()
async def coffeelb(self, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await self.formatted_lb('coffee', args, int(user.id)))
@commands.command()
async def water(self, ctx: commands.Context, *, args: Optional[str]=None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await self.parse_period(channelid, '')
counter = 'water'
await self.add_to_counter(
counter,
userid,
1,
context='cmd: water'
)
lb = await self.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(f"Good job hydrating! We have had {total} cups of water {period.value[0]}.")
@commands.command()
async def waterlb(self, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
@commands.command()
async def stuff(self, ctx: commands.Context, *, args: str = ''):
await ctx.reply(f"Stuff {args}")
@cmds.hybrid_command('water')
async def d_water_cmd(self, ctx):
await ctx.reply(repr(ctx))

View File

@@ -10,7 +10,8 @@ class CounterData(Registry):
CREATE TABLE counters(
counterid SERIAL PRIMARY KEY,
name TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
category TEXT
);
CREATE UNIQUE INDEX counters_name ON counters (name);
"""
@@ -19,6 +20,7 @@ class CounterData(Registry):
counterid = Integer(primary=True)
name = String()
category = String()
created_at = Timestamp()
class CounterEntry(RowModel):
@@ -31,7 +33,8 @@ class CounterData(Registry):
userid INTEGER NOT NULL,
value INTEGER NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
context_str TEXT
context_str TEXT,
details TEXT
);
CREATE INDEX counter_log_counterid ON counter_log (counterid);
"""
@@ -44,5 +47,28 @@ class CounterData(Registry):
value = Integer()
created_at = Timestamp()
context_str = String()
details = String()
class CounterCommand(RowModel):
"""
Schema
------
CREATE TABLE counter_commands(
name TEXT PRIMARY KEY,
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
lbname TEXT,
undoname TEXT,
response TEXT NOT NULL
);
"""
# NOTE: This table will be replaced by aliases soon anyway
# So no need to worry about integrity or future-proofing
_tablename_ = 'counter_commands'
_cache_ = {}
name = String(primary=True)
counterid = Integer()
lbname = String()
undoname = String()
response = String()

View File

@@ -0,0 +1,9 @@
ALTER TABLE counters ADD COLUMN category TEXT;
ALTER TABLE counter_log ADD COLUMN details TEXT;
CREATE TABLE counter_commands(
name TEXT PRIMARY KEY,
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
lbname TEXT,
undoname TEXT,
response TEXT NOT NULL
);

View File

@@ -47,16 +47,32 @@ class TimerChannel(Channel):
super().__init__(**kwargs)
self.cog = cog
self.channelid = 1261999440160624734
self.goal = 12
async def on_connection(self, websocket, event):
await super().on_connection(websocket, event)
timer = self.cog.get_channel_timer(1261999440160624734)
if timer is not None:
await self.send_set(
timer.data.last_started,
timer.data.focus_length,
timer.data.break_length,
websocket=websocket,
)
await self.send_set(
**await self.get_args_for(self.channelid),
goal=self.goal,
websocket=websocket,
)
async def send_updates(self):
await self.send_set(
**await self.get_args_for(self.channelid),
goal=self.goal,
)
async def get_args_for(self, channelid):
timer = self.cog.get_channel_timer(channelid)
if timer is None:
raise ValueError(f"Timer {channelid} doesn't exist.")
return {
'start_at': timer.data.last_started,
'focus_length': timer.data.focus_length,
'break_length': timer.data.break_length,
}
async def send_set(self, start_at, focus_length, break_length, goal=12, websocket=None):
await self.send_event({
@@ -304,8 +320,6 @@ class TimerCog(LionCog):
return
if member.bot:
return
if 1148167212901859328 not in [role.id for role in member.roles]:
return
# If a member is leaving or joining a running timer, trigger a status update
if before.channel != after.channel:
@@ -315,6 +329,7 @@ class TimerCog(LionCog):
tasks = []
if leaving is not None:
tasks.append(asyncio.create_task(leaving.update_status_card()))
leaving.last_seen.pop(member.id, None)
if joining is not None:
joining.last_seen[member.id] = utc_now()
if not joining.running and joining.auto_restart:
@@ -1059,8 +1074,18 @@ class TimerCog(LionCog):
@low_management_ward
async def streamtimer_update_cmd(self, ctx: LionContext,
new_start: Optional[str] = None,
new_goal: int = 12):
timer = self.get_channel_timer(1261999440160624734)
new_goal: Optional[int] = None,
new_channel: Optional[discord.VoiceChannel] = None,
):
if new_channel is not None:
channelid = self.channel.channelid = new_channel.id
else:
channelid = self.channel.channelid
if new_goal is not None:
self.channel.goal = new_goal
timer = self.get_channel_timer(channelid)
if timer is None:
return
if new_start:
@@ -1068,10 +1093,5 @@ class TimerCog(LionCog):
start_at = await self.bot.get_cog('Reminders').parse_time_static(new_start, timezone)
await timer.data.update(last_started=start_at)
await self.channel.send_set(
timer.data.last_started,
timer.data.focus_length,
timer.data.break_length,
goal=new_goal,
)
await self.channel.send_updates()
await ctx.reply("Stream Timer Updated")

View File

@@ -17,9 +17,19 @@ class ShoutoutCog(LionCog):
and drop a follow! \
They {areorwere} streaming {game} at {channel}
"""
COWO_SHOUTOUT = """
We think that {name} is a great coworker and you should check them out for more productive vibes! \
They {areorwere} streaming {game} at {channel}
"""
ART_SHOUTOUT = """
We think that {name} is an awesome artist and you should check them out for cool art and cosy vibes! \
They {areorwere} streaming {game} at {channel}
"""
def __init__(self, bot: LionBot):
self.bot = bot
self.crocbot = bot.crocbot
self.crocbot: CrocBot = bot.crocbot
self.data = bot.db.load_registry(ShoutoutData())
self.loaded = asyncio.Event()
@@ -59,19 +69,28 @@ class ShoutoutCog(LionCog):
return replace_multiple(text, mapping)
@commands.command(aliases=['so'])
async def shoutout(self, ctx: commands.Context, user: twitchio.User):
async def shoutout(self, ctx: commands.Context, target: str, typ: Optional[str]=None):
# Make sure caller is mod/broadcaster
# Lookup custom shoutout for this user
# If it exists use it, otherwise use default shoutout
if (ctx.author.is_mod or ctx.author.is_broadcaster):
data = await self.data.CustomShoutout.fetch(int(user.id))
if data:
shoutout = data.content
user = await self.crocbot.seek_user(target)
if user is None:
await ctx.reply(f"Couldn't resolve '{target}' to a valid user.")
else:
shoutout = self.DEFAULT_SHOUTOUT
formatted = await self.format_shoutout(shoutout, user)
await ctx.reply(formatted)
data = await self.data.CustomShoutout.fetch(int(user.id))
if data:
shoutout = data.content
elif typ == 'cowo':
shoutout = self.COWO_SHOUTOUT
elif typ == 'art':
shoutout = self.ART_SHOUTOUT
else:
shoutout = self.DEFAULT_SHOUTOUT
formatted = await self.format_shoutout(shoutout, user)
await ctx.reply(formatted)
# TODO: How to /shoutout with lib?
# TODO Shoutout queue
@commands.command()
async def editshoutout(self, ctx: commands.Context, user: twitchio.User, *, text: str):

View File

@@ -0,0 +1,7 @@
import logging
logger = logging.getLogger(__name__)
async def setup(bot):
from .cog import VoiceRoleCog
await bot.add_cog(VoiceRoleCog(bot))

View File

@@ -0,0 +1,166 @@
from collections import defaultdict
from typing import Optional
import asyncio
from cachetools import FIFOCache
from weakref import WeakValueDictionary
import discord
from discord.abc import GuildChannel
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionCog, LionContext
from meta.logger import log_wrap
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
from utils.ui import Confirm
from . import logger
from .data import VoiceRoleData
class VoiceRoleCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(VoiceRoleData())
self._event_locks: WeakValueDictionary[tuple[int, int], asyncio.Lock] = WeakValueDictionary()
async def cog_load(self):
await self.data.init()
@LionCog.listener('on_voice_state_update')
@log_wrap(action='Voice Role Update')
async def voicerole_update(self, member: discord.Member,
before: discord.VoiceState, after: discord.VoiceState):
if member.bot:
return
after_channel = after.channel
before_channel = before.channel
if after_channel == before_channel:
return
task_key = (member.guild.id, member.id)
async with self.event_lock(task_key):
# Get the roles of the channel they left to remove
# Get the roles of the channel they are joining to add
# Use a set difference to remove the roles to be added from the ones to remove
if before_channel is not None:
leaving_roles = await self.get_roles_for(before_channel.id)
else:
leaving_roles = []
if after_channel is not None:
gaining_roles = await self.get_roles_for(after_channel.id)
else:
gaining_roles = []
to_remove = []
for role in leaving_roles:
if role in member.roles and role not in gaining_roles and role.is_assignable():
to_remove.append(role)
to_add = []
for role in gaining_roles:
if role not in member.roles and role.is_assignable():
to_add.append(role)
if to_remove:
await member.remove_roles(*to_remove, reason="Removing voice channel associated roles.")
if to_add:
await member.add_roles(*to_add, reason="Adding voice channel associated roles.")
logger.info(
f"Voice roles removed {len(to_remove)} roles "
f"and added {len(to_add)} roles to <uid: {member.id}>"
)
async def get_roles_for(self, channelid: int) -> list[discord.Role]:
"""
Get the voice roles associated to the given channel, as a list.
Returns an empty list if there are no associated voice roles.
"""
rows = await self.data.VoiceRole.fetch_where(channelid=channelid)
channel = self.bot.get_channel(channelid)
if not channel:
raise ValueError("Provided voice role target channel is not in cache.")
target_roles = []
for row in rows:
role = channel.guild.get_role(row.roleid)
if role is not None:
target_roles.append(role)
return target_roles
def event_lock(self, key) -> asyncio.Lock:
"""
Get an asyncio.Lock for the given key.
Guarantees sequential event handling.
"""
lock = self._event_locks.get(key, None)
if lock is None:
lock = self._event_locks[key] = asyncio.Lock()
logger.debug(f"Getting video event lock {key} (locked: {lock.locked()})")
return lock
# -------- Commands --------
@cmds.hybrid_group(
name='voiceroles',
description="Base command group for voice channel -> role associationes."
)
@appcmds.default_permissions(manage_channels=True)
async def voicerole_group(self, ctx: LionContext):
...
@voicerole_group.command(
name="link",
description="Link a given voice channel with a given role."
)
@appcmds.describe(
channel="The voice channel to link.",
role="The associated role to give to members joining the voice channel."
)
async def voicerole_link(self, ctx: LionContext,
channel: discord.VoiceChannel,
role: discord.Role):
if not ctx.interaction:
return
if not channel.permissions_for(ctx.author).manage_channels:
await ctx.error_reply(f"You don't have the manage channels permission in {channel.mention}")
return
if not ctx.author.guild_permissions.manage_roles or not (role < ctx.author.top_role):
await ctx.error_reply(f"You don't have the permission to manage this role!")
return
await self.data.VoiceRole.table.insert(channelid=channel.id, roleid=role.id)
await ctx.reply("Voice role associated!")
@voicerole_group.command(
name="unlink",
description="Unlink a given voice channel from a given role."
)
@appcmds.describe(
channel="The voice channel to unlink.",
role="The role to remove from this voice channel."
)
async def voicerole_unlink(self, ctx: LionContext,
channel: discord.VoiceChannel,
role: discord.Role):
if not ctx.interaction:
return
if not channel.permissions_for(ctx.author).manage_channels:
await ctx.error_reply(f"You don't have the manage channels permission in {channel.mention}")
return
if not ctx.author.guild_permissions.manage_roles or not (role < ctx.author.top_role):
await ctx.error_reply(f"You don't have the permission to manage this role!")
return
await self.data.VoiceRole.table.delete_where(channelid=channel.id, roleid=role.id)
await ctx.reply("Voice role disassociated!")
# TODO: Display and visual editing of roles.

View File

@@ -0,0 +1,27 @@
from data import Registry, RowModel
from data.columns import Integer, Timestamp
class VoiceRoleData(Registry):
class VoiceRole(RowModel):
"""
Schema
------
CREATE TABLE voice_roles(
voice_role_id SERIAL PRIMARY KEY,
channelid BIGINT NOT NULL,
roleid BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX voice_role_channels on voice_roles (channelid);
"""
# TODO: Worth associating a guildid to this as well? Denormalises though
# Makes more theoretical sense to associated configurable channels to the guilds in a join table.
_tablename_ = 'voice_roles'
_cache_ = {}
voice_role_id = Integer(primary=True)
channelid = Integer()
roleid = Integer()
created_at = Timestamp()