293 lines
10 KiB
Python
293 lines
10 KiB
Python
import asyncio
|
|
from enum import Enum
|
|
from typing import Optional
|
|
from datetime import timedelta
|
|
|
|
import discord
|
|
from discord.ext import commands as cmds
|
|
import twitchio
|
|
from twitchio.ext import commands
|
|
|
|
|
|
from data.queries import ORDER
|
|
from meta import LionCog, LionBot, CrocBot
|
|
from utils.lib import utc_now
|
|
from . import logger
|
|
from .data import CounterData
|
|
|
|
|
|
class PERIOD(Enum):
    """
    Time windows over which counter entries may be aggregated.

    Each member's value is a tuple of strings, all of which are accepted
    (case-insensitively) as user-supplied aliases for the period.
    Two positions in the tuple carry additional meaning:

    - ``value[0]`` is the phrase inserted into chat replies
      (e.g. "... cups of tea this week."). It is empty for ``ALL``.
    - ``value[-1]`` is the label used in leaderboard headings
      (e.g. "tea weekly leaderboard").
    """
    ALL = ('', 'all', 'all-time')
    STREAM = ('this stream', 'stream')
    # 'today' doubles as the reply phrase and an alias, so it is listed once.
    DAY = ('today', 'd', 'day', 'daily')
    WEEK = ('this week', 'w', 'week', 'weekly')
    MONTH = ('this month', 'm', 'mo', 'month', 'monthly')
    YEAR = ('this year', 'y', 'year', 'yearly')
|
|
|
|
|
|
class CounterCog(LionCog):
    """
    Cog providing named counters, with per-user entries stored via ``CounterData``.

    Exposes a small async API (fetch/delete/reset/add/leaderboard/totals)
    along with Twitch chat commands built on top of it.
    """

    def __init__(self, bot: LionBot):
        self.bot = bot
        self.crocbot: CrocBot = bot.crocbot
        self.data = bot.db.load_registry(CounterData())

        # Set at the end of cog_load, once the registry and cache are ready.
        self.loaded = asyncio.Event()

        # Cache of counter names -> rows
        self.counters = {}

    async def cog_load(self):
        """
        Register the Twitch commands, initialise the data registry,
        populate the counter cache, and then mark the cog as loaded.
        """
        self._load_twitch_methods(self.crocbot)

        await self.data.init()
        await self.load_counters()
        self.loaded.set()

    async def cog_unload(self):
        """
        Unregister the Twitch commands from the Twitch bot.
        """
        self._unload_twitch_methods(self.crocbot)

    async def cog_check(self, ctx):
        """
        Cog-wide check. Always passes.
        """
        return True

    async def load_counters(self):
        """
        Initialise counter name cache.

        Replaces ``self.counters`` with a fresh name -> row mapping
        built from every stored Counter row.
        """
        rows = await self.data.Counter.fetch_where()
        self.counters = {row.name: row for row in rows}
        logger.info(
            f"Loaded {len(self.counters)} counters."
        )

    # Counters API

    async def fetch_counter(self, counter: str) -> CounterData.Counter:
        """
        Fetches the Counter with the given name,
        or creates it if it doesn't exist.

        The resulting row is cached under its name for later lookups.
        """
        row = self.counters.get(counter)
        if row is None:
            row = await self.data.Counter.fetch_or_create(name=counter)
            self.counters[counter] = row
        return row

    async def delete_counter(self, counter: str):
        """
        Drop the named counter from the cache and delete its Counter row.
        """
        self.counters.pop(counter, None)
        await self.data.Counter.table.delete_where(name=counter)

    async def reset_counter(self, counter: str):
        """
        Delete every entry of the named counter, keeping the counter itself.

        Only counters present in the cache are reset;
        an unknown name is silently ignored.
        """
        row = self.counters.get(counter)
        if row is not None:
            await self.data.CounterEntry.table.delete_where(counterid=row.counterid)

    async def add_to_counter(self, counter: str, userid: int, value: int, context: Optional[str]=None):
        """
        Record a new entry of ``value`` for ``userid`` on the named counter.

        The counter is created if it does not exist yet.
        ``context`` is stored alongside the entry as ``context_str``.
        Returns the created CounterEntry row.
        """
        row = await self.fetch_counter(counter)
        return await self.data.CounterEntry.create(
            counterid=row.counterid,
            userid=userid,
            value=value,
            context_str=context
        )

    async def leaderboard(self, counter: str, start_time=None):
        """
        Compute per-user totals for the named counter.

        If ``start_time`` is given, only entries with ``created_at``
        at or after it are included.

        Returns a dict mapping userid -> summed value,
        inserted in descending order of total.
        """
        row = await self.fetch_counter(counter)
        query = self.data.CounterEntry.table.select_where(counterid=row.counterid)
        query.select('userid', user_total="SUM(value)")
        query.group_by('userid')
        query.order_by('user_total', ORDER.DESC)
        if start_time is not None:
            query.where(self.data.CounterEntry.created_at >= start_time)
        query.with_no_adapter()
        results = await query

        return {result['userid']: result['user_total'] for result in results}

    async def personal_total(self, counter: str, userid: int):
        """
        Return the summed value of all entries by ``userid`` on the named counter.

        Returns 0 when the user has no entries.
        """
        row = await self.fetch_counter(counter)
        query = self.data.CounterEntry.table.select_where(counterid=row.counterid, userid=userid)
        query.select(user_total="SUM(value)")
        query.with_no_adapter()
        results = await query

        # An ungrouped SQL SUM over zero rows is NULL, so a row may be
        # returned with a None total. Treat that the same as "no rows".
        total = results[0]['user_total'] if results else None
        return total if total is not None else 0

    async def totals(self, counter):
        """
        Return the summed value of all entries on the named counter.

        Returns 0 when the counter has no entries.
        """
        row = await self.fetch_counter(counter)
        query = self.data.CounterEntry.table.select_where(counterid=row.counterid)
        query.select(counter_total="SUM(value)")
        query.with_no_adapter()
        results = await query

        # An ungrouped SQL SUM over zero rows is NULL, so a row may be
        # returned with a None total. Treat that the same as "no rows".
        total = results[0]['counter_total'] if results else None
        return total if total is not None else 0

    # Counters commands
    @commands.command()
    async def counter(self, ctx: commands.Context, name: str, subcmd: Optional[str], *, args: Optional[str]=None):
        """
        Moderator/broadcaster command for managing an arbitrary counter.

        Subcommands:
            show (or none): Reply with the counter total.
            add [value]: Add ``value`` (default 1) on behalf of the author.
            lb [period]: Reply with the leaderboard for the given period.
            clear: Remove all entries of the counter.

        Invocations from anyone else are silently ignored.
        """
        if not (ctx.author.is_mod or ctx.author.is_broadcaster):
            return

        # Counter names are case-insensitive from the command's point of view.
        name = name.lower()

        if subcmd is None or subcmd == 'show':
            # Show
            total = await self.totals(name)
            await ctx.reply(f"'{name}' counter is: {total}")
        elif subcmd == 'add':
            if args is None:
                value = 1
            else:
                try:
                    value = int(args)
                except ValueError:
                    await ctx.reply("Could not parse value to add.")
                    return
            await self.add_to_counter(
                name,
                int(ctx.author.id),
                value,
                context='cmd: counter add'
            )
            total = await self.totals(name)
            await ctx.reply(f"'{name}' counter is now: {total}")
        elif subcmd == 'lb':
            user = await ctx.channel.user()
            lbstr = await self.formatted_lb(name, args or '', int(user.id))
            await ctx.reply(lbstr)
        elif subcmd == 'clear':
            await self.reset_counter(name)
            await ctx.reply(f"'{name}' counter reset.")
        else:
            await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear'.")

    async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
        """
        Resolve a user-supplied period string into a period and start time.

        ``periodstr`` is matched case-insensitively (ignoring surrounding
        whitespace) against the aliases of each ``PERIOD`` member.
        An empty string selects ``default``.

        For ``PERIOD.STREAM`` the start time is the start of the current
        stream of the channel ``userid``; if no stream is found,
        this falls back to ``PERIOD.ALL``.

        Returns a ``(period, start_time)`` tuple,
        where ``start_time`` is None for ``PERIOD.ALL``.

        Raises ValueError if a non-empty ``periodstr`` matches no period.
        """
        key = (periodstr or '').strip().lower()
        if key:
            period = next((period for period in PERIOD if key in period.value), None)
            if period is None:
                raise ValueError("Invalid period string provided")
        else:
            period = default

        now = utc_now()
        today = now.replace(hour=0, minute=0, second=0, microsecond=0)

        if period is PERIOD.ALL:
            start_time = None
        elif period is PERIOD.STREAM:
            streams = await self.crocbot.fetch_streams(user_ids=[userid])
            if streams:
                start_time = streams[0].started_at
            else:
                # No stream found, so there is no stream start to measure from.
                period = PERIOD.ALL
                start_time = None
        elif period is PERIOD.DAY:
            start_time = today
        elif period is PERIOD.WEEK:
            # weekday() is 0 on Monday, so this is the most recent Monday.
            start_time = today - timedelta(days=today.weekday())
        elif period is PERIOD.MONTH:
            start_time = today.replace(day=1)
        elif period is PERIOD.YEAR:
            start_time = today.replace(day=1, month=1)
        else:
            # Defensive fallback for a `default` that is not a known PERIOD.
            period = PERIOD.ALL
            start_time = None

        return (period, start_time)

    async def formatted_lb(self, counter: str, periodstr: str, channelid: int):
        """
        Build the chat-ready leaderboard string for the named counter.

        ``periodstr`` is resolved through ``parse_period`` with
        ``channelid`` as the channel whose stream defines "this stream".
        Users whose display name is not returned by the user lookup
        are shown by their id instead.

        Raises ValueError if ``periodstr`` is not a recognised period.
        """
        period, start_time = await self.parse_period(channelid, periodstr)
        label = period.value[-1]

        lb = await self.leaderboard(counter, start_time=start_time)
        if not lb:
            return f"{counter} {label} leaderboard is empty!"

        users = await self.crocbot.fetch_users(ids=list(lb))
        name_map = {user.id: user.display_name for user in users}
        lbstr = '; '.join(
            f"{name_map.get(userid, str(userid))}: {total}"
            for userid, total in lb.items()
        )
        return f"{counter} {label} leaderboard --- {lbstr}"

    async def _drink_command(self, ctx: commands.Context, counter: str, greeting: str):
        """
        Shared body of the drink commands: add one to ``counter`` for the
        author, then reply with ``greeting`` and the total for the default period.
        """
        userid = int(ctx.author.id)
        channelid = int((await ctx.channel.user()).id)
        period, start_time = await self.parse_period(channelid, '')

        await self.add_to_counter(
            counter,
            userid,
            1,
            context=f'cmd: {counter}'
        )
        lb = await self.leaderboard(counter, start_time=start_time)
        total = sum(lb.values())

        # PERIOD.ALL has an empty phrase; avoid a dangling space before the full stop.
        phrase = period.value[0]
        suffix = f" {phrase}" if phrase else ""
        await ctx.reply(f"{greeting} We have had {total} cups of {counter}{suffix}.")

    # Misc actual counter commands
    # TODO: Factor this out to a different module...
    @commands.command()
    async def tea(self, ctx: commands.Context, *, args: Optional[str]=None):
        """
        Add a cup of tea for the author and reply with the running total.
        """
        await self._drink_command(ctx, 'tea', "Enjoy your tea!")

    @commands.command()
    async def tealb(self, ctx: commands.Context, *, args: str = ''):
        """
        Reply with the tea leaderboard for the period given in ``args``.
        """
        user = await ctx.channel.user()
        await ctx.reply(await self.formatted_lb('tea', args, int(user.id)))

    @commands.command()
    async def coffee(self, ctx: commands.Context, *, args: Optional[str]=None):
        """
        Add a cup of coffee for the author and reply with the running total.
        """
        await self._drink_command(ctx, 'coffee', "Enjoy your coffee!")

    @commands.command()
    async def coffeelb(self, ctx: commands.Context, *, args: str = ''):
        """
        Reply with the coffee leaderboard for the period given in ``args``.
        """
        user = await ctx.channel.user()
        await ctx.reply(await self.formatted_lb('coffee', args, int(user.id)))

    @commands.command()
    async def water(self, ctx: commands.Context, *, args: Optional[str]=None):
        """
        Add a cup of water for the author and reply with the running total.
        """
        await self._drink_command(ctx, 'water', "Good job hydrating!")

    @commands.command()
    async def waterlb(self, ctx: commands.Context, *, args: str = ''):
        """
        Reply with the water leaderboard for the period given in ``args``.
        """
        user = await ctx.channel.user()
        await ctx.reply(await self.formatted_lb('water', args, int(user.id)))

    @commands.command()
    async def stuff(self, ctx: commands.Context, *, args: str = ''):
        """
        Echo ``args`` back, prefixed with "Stuff".
        """
        await ctx.reply(f"Stuff {args}")

    @cmds.hybrid_command('water')
    async def d_water_cmd(self, ctx):
        """
        Discord-side 'water' command. Currently just replies with the repr of the context.
        """
        await ctx.reply(repr(ctx))
|