Merge branch 'feat-counter-refactor' into staging
This commit is contained in:
@@ -1454,6 +1454,7 @@ CREATE TABLE shoutouts(
|
|||||||
CREATE TABLE counters(
|
CREATE TABLE counters(
|
||||||
counterid SERIAL PRIMARY KEY,
|
counterid SERIAL PRIMARY KEY,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
|
category TEXT,
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
);
|
);
|
||||||
CREATE UNIQUE INDEX counters_name ON counters (name);
|
CREATE UNIQUE INDEX counters_name ON counters (name);
|
||||||
@@ -1464,6 +1465,7 @@ CREATE TABLE counter_log(
|
|||||||
userid INTEGER NOT NULL,
|
userid INTEGER NOT NULL,
|
||||||
value INTEGER NOT NULL,
|
value INTEGER NOT NULL,
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
details TEXT,
|
||||||
context_str TEXT
|
context_str TEXT
|
||||||
);
|
);
|
||||||
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
||||||
|
|||||||
@@ -47,6 +47,27 @@ class LionCog(Cog):
|
|||||||
|
|
||||||
return await super()._inject(bot, *args, *kwargs)
|
return await super()._inject(bot, *args, *kwargs)
|
||||||
|
|
||||||
|
def add_twitch_command(self, bot: Bot, command: Command):
|
||||||
|
"""
|
||||||
|
Dynamically register a command with the given bot.
|
||||||
|
|
||||||
|
The command will be deregistered on cog unload.
|
||||||
|
"""
|
||||||
|
# Remove any conflicting commands
|
||||||
|
if cmd := bot.get_command(command.name):
|
||||||
|
bot.remove_command(cmd.name)
|
||||||
|
self._twitch_cmds_.pop(command.name, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._twitch_cmds_[command.name] = command
|
||||||
|
command._instance = self
|
||||||
|
command.cog = self
|
||||||
|
bot.add_command(command)
|
||||||
|
except Exception:
|
||||||
|
# Ensure the command doesn't die in the internal command cache
|
||||||
|
self._twitch_cmds_.pop(command.name, None)
|
||||||
|
raise
|
||||||
|
|
||||||
def _load_twitch_methods(self, bot: Bot):
|
def _load_twitch_methods(self, bot: Bot):
|
||||||
for name, command in self._twitch_cmds_.items():
|
for name, command in self._twitch_cmds_.items():
|
||||||
command._instance = self
|
command._instance = self
|
||||||
|
|||||||
@@ -25,6 +25,75 @@ class PERIOD(Enum):
|
|||||||
YEAR = ('this year', 'y', 'year', 'yearly')
|
YEAR = ('this year', 'y', 'year', 'yearly')
|
||||||
|
|
||||||
|
|
||||||
|
def counter_cmd_factory(
|
||||||
|
counter: str,
|
||||||
|
response: str,
|
||||||
|
default_period: Optional[PERIOD] = PERIOD.STREAM,
|
||||||
|
context: Optional[str] = None
|
||||||
|
):
|
||||||
|
context = context or f"cmd: {counter}"
|
||||||
|
async def counter_cmd(cog, ctx: commands.Context, *, args: Optional[str] = None):
|
||||||
|
userid = int(ctx.author.id)
|
||||||
|
channelid = int((await ctx.channel.user()).id)
|
||||||
|
period, start_time = await cog.parse_period(userid, '', default=default_period)
|
||||||
|
|
||||||
|
args = (args or '').strip(" ")
|
||||||
|
splits = args.split(maxsplit=1)
|
||||||
|
splits = [split.strip() for split in splits if split]
|
||||||
|
|
||||||
|
details = None
|
||||||
|
amount = 1
|
||||||
|
|
||||||
|
if splits:
|
||||||
|
if splits[0].isdigit() or (splits[0].startswith('-') and splits[0][1:].isdigit()):
|
||||||
|
amount = int(splits[0])
|
||||||
|
splits = splits[1:]
|
||||||
|
if splits:
|
||||||
|
details = ' '.join(splits)
|
||||||
|
|
||||||
|
await cog.add_to_counter(
|
||||||
|
counter, userid, amount,
|
||||||
|
context=context,
|
||||||
|
details=details
|
||||||
|
)
|
||||||
|
lb = await cog.leaderboard(counter, start_time=start_time)
|
||||||
|
user_total = lb.get(userid, 0)
|
||||||
|
total = sum(lb.values())
|
||||||
|
await ctx.reply(
|
||||||
|
response.format(
|
||||||
|
total=total,
|
||||||
|
period=period,
|
||||||
|
period_name=period.value[0],
|
||||||
|
detailsorname=details or counter,
|
||||||
|
user_total=user_total,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def lb_cmd(cog, ctx: commands.Context, *, args: str = ''):
|
||||||
|
user = await ctx.channel.user()
|
||||||
|
await ctx.reply(await cog.formatted_lb(counter, args, int(user.id)))
|
||||||
|
|
||||||
|
async def undo_cmd(cog, ctx: commands.Context):
|
||||||
|
userid = int(ctx.author.id)
|
||||||
|
channelid = int((await ctx.channel.user()).id)
|
||||||
|
_counter = await cog.fetch_counter(counter)
|
||||||
|
query = cog.data.CounterEntry.fetch_where(
|
||||||
|
counterid=_counter.counterid,
|
||||||
|
userid=userid,
|
||||||
|
)
|
||||||
|
query.order_by('created_at', direction=ORDER.DESC)
|
||||||
|
query.limit(1)
|
||||||
|
results = await query
|
||||||
|
if not results:
|
||||||
|
await ctx.reply("Nothing to delete!")
|
||||||
|
else:
|
||||||
|
row = results[0]
|
||||||
|
await row.delete()
|
||||||
|
await ctx.reply("Undo successful!")
|
||||||
|
|
||||||
|
return (counter_cmd, lb_cmd, undo_cmd)
|
||||||
|
|
||||||
|
|
||||||
class CounterCog(LionCog):
|
class CounterCog(LionCog):
|
||||||
def __init__(self, bot: LionBot):
|
def __init__(self, bot: LionBot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
@@ -38,6 +107,7 @@ class CounterCog(LionCog):
|
|||||||
|
|
||||||
async def cog_load(self):
|
async def cog_load(self):
|
||||||
self._load_twitch_methods(self.crocbot)
|
self._load_twitch_methods(self.crocbot)
|
||||||
|
await self.load_counter_commands()
|
||||||
|
|
||||||
await self.data.init()
|
await self.data.init()
|
||||||
await self.load_counters()
|
await self.load_counters()
|
||||||
@@ -46,6 +116,55 @@ class CounterCog(LionCog):
|
|||||||
async def cog_unload(self):
|
async def cog_unload(self):
|
||||||
self._unload_twitch_methods(self.crocbot)
|
self._unload_twitch_methods(self.crocbot)
|
||||||
|
|
||||||
|
async def load_counter_commands(self):
|
||||||
|
rows = await self.data.CounterCommand.fetch_where()
|
||||||
|
for row in rows:
|
||||||
|
counter = await self.data.Counter.fetch(row.counterid)
|
||||||
|
counter_cb, lb_cb, undo_cb = counter_cmd_factory(
|
||||||
|
counter.name,
|
||||||
|
row.response
|
||||||
|
)
|
||||||
|
cmds = []
|
||||||
|
main_cmd = commands.command(name=row.name)(counter_cb)
|
||||||
|
cmds.append(main_cmd)
|
||||||
|
if row.lbname:
|
||||||
|
lb_cmd = commands.command(name=row.lbname)(lb_cb)
|
||||||
|
cmds.append(lb_cmd)
|
||||||
|
if row.undoname:
|
||||||
|
undo_cmd = commands.command(name=row.undoname)(undo_cb)
|
||||||
|
cmds.append(undo_cmd)
|
||||||
|
|
||||||
|
for cmd in cmds:
|
||||||
|
self.add_twitch_command(self.crocbot, cmd)
|
||||||
|
|
||||||
|
logger.info(f"(Re)Loaded {len(rows)} counter commands!")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# commands = {
|
||||||
|
# 'stuff': (
|
||||||
|
# 'stuffcounter',
|
||||||
|
# 'stufflb',
|
||||||
|
# "Good luck with {detailsorname}! We have done {total} stuffs {period_name}."
|
||||||
|
# ),
|
||||||
|
# 'water': (
|
||||||
|
# 'water',
|
||||||
|
# 'waterlb',
|
||||||
|
# "Good job hydrating! We have had {total} cups of tea {period_name}."
|
||||||
|
# ),
|
||||||
|
# 'coffee': (
|
||||||
|
# 'coffee',
|
||||||
|
# 'coffeelb',
|
||||||
|
# "Enjoy your {detailsorname}! We have had {total} cups of coffee {period_name}."
|
||||||
|
# ),
|
||||||
|
# 'tea': (
|
||||||
|
# 'tea',
|
||||||
|
# 'tealb',
|
||||||
|
# "Enjoy your {detailsorname}! We have had {total} cups of tea this {period_name}."
|
||||||
|
# ),
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
async def cog_check(self, ctx):
|
async def cog_check(self, ctx):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -80,13 +199,19 @@ class CounterCog(LionCog):
|
|||||||
if row:
|
if row:
|
||||||
await self.data.CounterEntry.table.delete_where(counterid=row.counterid)
|
await self.data.CounterEntry.table.delete_where(counterid=row.counterid)
|
||||||
|
|
||||||
async def add_to_counter(self, counter: str, userid: int, value: int, context: Optional[str]=None):
|
async def add_to_counter(
|
||||||
|
self,
|
||||||
|
counter: str, userid: int, value: int,
|
||||||
|
context: Optional[str]=None,
|
||||||
|
details: Optional[str]=None,
|
||||||
|
):
|
||||||
row = await self.fetch_counter(counter)
|
row = await self.fetch_counter(counter)
|
||||||
return await self.data.CounterEntry.create(
|
return await self.data.CounterEntry.create(
|
||||||
counterid=row.counterid,
|
counterid=row.counterid,
|
||||||
userid=userid,
|
userid=userid,
|
||||||
value=value,
|
value=value,
|
||||||
context_str=context
|
context_str=context,
|
||||||
|
details=details
|
||||||
)
|
)
|
||||||
|
|
||||||
async def leaderboard(self, counter: str, start_time=None):
|
async def leaderboard(self, counter: str, start_time=None):
|
||||||
@@ -155,8 +280,43 @@ class CounterCog(LionCog):
|
|||||||
elif subcmd == 'clear':
|
elif subcmd == 'clear':
|
||||||
await self.reset_counter(name)
|
await self.reset_counter(name)
|
||||||
await ctx.reply(f"'{name}' counter reset.")
|
await ctx.reply(f"'{name}' counter reset.")
|
||||||
|
elif subcmd == 'alias':
|
||||||
|
splits = args.split(maxsplit=3) if args else []
|
||||||
|
counter = await self.fetch_counter(name)
|
||||||
|
rows = await self.data.CounterCommand.fetch_where(counterid=counter.counterid)
|
||||||
|
existing = rows[0] if rows else None
|
||||||
|
if existing and not args:
|
||||||
|
# Show current alias
|
||||||
|
await ctx.reply(
|
||||||
|
f"Counter '{name}' aliases: '!{existing.name}' to add to counter; "
|
||||||
|
f"'!{existing.lbname}' to view counter leaderboard; "
|
||||||
|
f"'!{existing.undoname}' to undo (your) last addition."
|
||||||
|
)
|
||||||
|
elif len(splits) < 4:
|
||||||
|
# Show usage
|
||||||
|
await ctx.reply(
|
||||||
|
"USAGE: !counter <name> alias <cmdname> <lbname> <undoname> <response> -- "
|
||||||
|
"Response accepts keywords {total}, {period}, {period_name}, {detailsorname}, {user_total}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create new alias
|
||||||
|
cmdname, lbname, undoname, response = splits
|
||||||
|
# Remove any existing alias
|
||||||
|
await self.data.CounterCommand.table.delete_where(name=cmdname)
|
||||||
|
|
||||||
|
alias = await self.data.CounterCommand.create(
|
||||||
|
name=cmdname,
|
||||||
|
counterid=counter.counterid,
|
||||||
|
lbname=lbname, undoname=undoname, response=response
|
||||||
|
)
|
||||||
|
await self.load_counter_commands()
|
||||||
|
await ctx.reply(
|
||||||
|
f"Alias created for counter '{name}': '!{alias.name}' to add to counter; "
|
||||||
|
f"'!{alias.lbname}' to view counter leaderboard; "
|
||||||
|
f"'!{alias.undoname}' to undo (your) last addition."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear'.")
|
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear', 'alias'.")
|
||||||
|
|
||||||
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
|
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
|
||||||
if periodstr:
|
if periodstr:
|
||||||
@@ -211,82 +371,3 @@ class CounterCog(LionCog):
|
|||||||
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}"
|
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}"
|
||||||
else:
|
else:
|
||||||
return f"{counter} {period.value[-1]} leaderboard is empty!"
|
return f"{counter} {period.value[-1]} leaderboard is empty!"
|
||||||
|
|
||||||
# Misc actual counter commands
|
|
||||||
# TODO: Factor this out to a different module...
|
|
||||||
@commands.command()
|
|
||||||
async def tea(self, ctx: commands.Context, *, args: Optional[str]=None):
|
|
||||||
userid = int(ctx.author.id)
|
|
||||||
channelid = int((await ctx.channel.user()).id)
|
|
||||||
period, start_time = await self.parse_period(channelid, '')
|
|
||||||
counter = 'tea'
|
|
||||||
|
|
||||||
await self.add_to_counter(
|
|
||||||
counter,
|
|
||||||
userid,
|
|
||||||
1,
|
|
||||||
context='cmd: tea'
|
|
||||||
)
|
|
||||||
lb = await self.leaderboard(counter, start_time=start_time)
|
|
||||||
user_total = lb.get(userid, 0)
|
|
||||||
total = sum(lb.values())
|
|
||||||
await ctx.reply(f"Enjoy your tea! We have had {total} cups of tea {period.value[0]}.")
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def tealb(self, ctx: commands.Context, *, args: str = ''):
|
|
||||||
user = await ctx.channel.user()
|
|
||||||
await ctx.reply(await self.formatted_lb('tea', args, int(user.id)))
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def coffee(self, ctx: commands.Context, *, args: Optional[str]=None):
|
|
||||||
userid = int(ctx.author.id)
|
|
||||||
channelid = int((await ctx.channel.user()).id)
|
|
||||||
period, start_time = await self.parse_period(channelid, '')
|
|
||||||
counter = 'coffee'
|
|
||||||
|
|
||||||
await self.add_to_counter(
|
|
||||||
counter,
|
|
||||||
userid,
|
|
||||||
1,
|
|
||||||
context='cmd: coffee'
|
|
||||||
)
|
|
||||||
lb = await self.leaderboard(counter, start_time=start_time)
|
|
||||||
user_total = lb.get(userid, 0)
|
|
||||||
total = sum(lb.values())
|
|
||||||
await ctx.reply(f"Enjoy your coffee! We have had {total} cups of coffee {period.value[0]}.")
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def coffeelb(self, ctx: commands.Context, *, args: str = ''):
|
|
||||||
user = await ctx.channel.user()
|
|
||||||
await ctx.reply(await self.formatted_lb('coffee', args, int(user.id)))
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def water(self, ctx: commands.Context, *, args: Optional[str]=None):
|
|
||||||
userid = int(ctx.author.id)
|
|
||||||
channelid = int((await ctx.channel.user()).id)
|
|
||||||
period, start_time = await self.parse_period(channelid, '')
|
|
||||||
counter = 'water'
|
|
||||||
|
|
||||||
await self.add_to_counter(
|
|
||||||
counter,
|
|
||||||
userid,
|
|
||||||
1,
|
|
||||||
context='cmd: water'
|
|
||||||
)
|
|
||||||
lb = await self.leaderboard(counter, start_time=start_time)
|
|
||||||
user_total = lb.get(userid, 0)
|
|
||||||
total = sum(lb.values())
|
|
||||||
await ctx.reply(f"Good job hydrating! We have had {total} cups of water {period.value[0]}.")
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def waterlb(self, ctx: commands.Context, *, args: str = ''):
|
|
||||||
user = await ctx.channel.user()
|
|
||||||
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
|
|
||||||
|
|
||||||
@commands.command()
|
|
||||||
async def stuff(self, ctx: commands.Context, *, args: str = ''):
|
|
||||||
await ctx.reply(f"Stuff {args}")
|
|
||||||
|
|
||||||
@cmds.hybrid_command('water')
|
|
||||||
async def d_water_cmd(self, ctx):
|
|
||||||
await ctx.reply(repr(ctx))
|
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ class CounterData(Registry):
|
|||||||
CREATE TABLE counters(
|
CREATE TABLE counters(
|
||||||
counterid SERIAL PRIMARY KEY,
|
counterid SERIAL PRIMARY KEY,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
category TEXT
|
||||||
);
|
);
|
||||||
CREATE UNIQUE INDEX counters_name ON counters (name);
|
CREATE UNIQUE INDEX counters_name ON counters (name);
|
||||||
"""
|
"""
|
||||||
@@ -19,6 +20,7 @@ class CounterData(Registry):
|
|||||||
|
|
||||||
counterid = Integer(primary=True)
|
counterid = Integer(primary=True)
|
||||||
name = String()
|
name = String()
|
||||||
|
category = String()
|
||||||
created_at = Timestamp()
|
created_at = Timestamp()
|
||||||
|
|
||||||
class CounterEntry(RowModel):
|
class CounterEntry(RowModel):
|
||||||
@@ -31,7 +33,8 @@ class CounterData(Registry):
|
|||||||
userid INTEGER NOT NULL,
|
userid INTEGER NOT NULL,
|
||||||
value INTEGER NOT NULL,
|
value INTEGER NOT NULL,
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
context_str TEXT
|
context_str TEXT,
|
||||||
|
details TEXT
|
||||||
);
|
);
|
||||||
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
||||||
"""
|
"""
|
||||||
@@ -44,5 +47,28 @@ class CounterData(Registry):
|
|||||||
value = Integer()
|
value = Integer()
|
||||||
created_at = Timestamp()
|
created_at = Timestamp()
|
||||||
context_str = String()
|
context_str = String()
|
||||||
|
details = String()
|
||||||
|
|
||||||
|
class CounterCommand(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE counter_commands(
|
||||||
|
name TEXT PRIMARY KEY,
|
||||||
|
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||||
|
lbname TEXT,
|
||||||
|
undoname TEXT,
|
||||||
|
response TEXT NOT NULL
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
# NOTE: This table will be replaced by aliases soon anyway
|
||||||
|
# So no need to worry about integrity or future-proofing
|
||||||
|
_tablename_ = 'counter_commands'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
name = String(primary=True)
|
||||||
|
counterid = Integer()
|
||||||
|
lbname = String()
|
||||||
|
undoname = String()
|
||||||
|
response = String()
|
||||||
|
|
||||||
|
|||||||
9
src/modules/counters/migration.sql
Normal file
9
src/modules/counters/migration.sql
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
ALTER TABLE counters ADD COLUMN category TEXT;
|
||||||
|
ALTER TABLE counter_log ADD COLUMN details TEXT;
|
||||||
|
CREATE TABLE counter_commands(
|
||||||
|
name TEXT PRIMARY KEY,
|
||||||
|
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||||
|
lbname TEXT,
|
||||||
|
undoname TEXT,
|
||||||
|
response TEXT NOT NULL
|
||||||
|
);
|
||||||
Reference in New Issue
Block a user