Initial merger with Twitch interface.
This commit is contained in:
13
src/modules/counters/__init__.py
Normal file
13
src/modules/counters/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import CounterCog
|
||||
|
||||
def prepare(bot):
|
||||
bot.add_cog(CounterCog(bot))
|
||||
|
||||
async def setup(bot):
|
||||
from .lion_cog import CounterCog
|
||||
|
||||
await bot.add_cog(CounterCog(bot))
|
||||
299
src/modules/counters/cog.py
Normal file
299
src/modules/counters/cog.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from datetime import timedelta
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
|
||||
from data.queries import ORDER
|
||||
from meta import CrocBot
|
||||
from utils.lib import utc_now
|
||||
from . import logger
|
||||
from .data import CounterData
|
||||
|
||||
|
||||
class PERIOD(Enum):
|
||||
ALL = ('', 'all', 'all-time')
|
||||
STREAM = ('this stream', 'stream',)
|
||||
DAY = ('today', 'd', 'day', 'today', 'daily')
|
||||
WEEK = ('this week', 'w', 'week', 'weekly')
|
||||
MONTH = ('this month', 'm', 'mo', 'month', 'monthly')
|
||||
YEAR = ('this year', 'y', 'year', 'yearly')
|
||||
|
||||
|
||||
class CounterCog(commands.Cog):
|
||||
def __init__(self, bot: CrocBot):
|
||||
self.bot = bot
|
||||
self.data = bot.data.load_registry(CounterData())
|
||||
|
||||
self.loaded = asyncio.Event()
|
||||
|
||||
# Cache of counter names -> rows
|
||||
self.counters = {}
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
self.loaded.set()
|
||||
|
||||
async def load_counters(self):
|
||||
"""
|
||||
Initialise counter name cache.
|
||||
"""
|
||||
rows = await self.data.Counter.fetch_where()
|
||||
self.counters = {row.name: row for row in rows}
|
||||
logger.info(
|
||||
f"Loaded {len(self.counters)} counters."
|
||||
)
|
||||
|
||||
async def ensure_loaded(self):
|
||||
if not self.loaded.is_set():
|
||||
await self.cog_load()
|
||||
|
||||
@commands.Cog.event('event_ready') # type: ignore
|
||||
async def on_ready(self):
|
||||
await self.ensure_loaded()
|
||||
|
||||
async def cog_check(self, ctx):
|
||||
await self.ensure_loaded()
|
||||
return True
|
||||
|
||||
# Counters API
|
||||
|
||||
async def fetch_counter(self, counter: str) -> CounterData.Counter:
|
||||
"""
|
||||
Fetches the Counter with the given name,
|
||||
or creates it if it doesn't exist.
|
||||
"""
|
||||
if (row := self.counters.get(counter, None)) is None:
|
||||
row = await self.data.Counter.fetch_or_create(name=counter)
|
||||
self.counters[counter] = row
|
||||
return row
|
||||
|
||||
async def delete_counter(self, counter: str):
|
||||
self.counters.pop(counter, None)
|
||||
await self.data.Counter.table.delete_where(name=counter)
|
||||
|
||||
async def reset_counter(self, counter: str):
|
||||
row = self.counters.get(counter, None)
|
||||
if row:
|
||||
await self.data.CounterEntry.table.delete_where(counterid=row.counterid)
|
||||
|
||||
async def add_to_counter(self, counter: str, userid: int, value: int, context: Optional[str]=None):
|
||||
row = await self.fetch_counter(counter)
|
||||
return await self.data.CounterEntry.create(
|
||||
counterid=row.counterid,
|
||||
userid=userid,
|
||||
value=value,
|
||||
context_str=context
|
||||
)
|
||||
|
||||
async def leaderboard(self, counter: str, start_time=None):
|
||||
row = await self.fetch_counter(counter)
|
||||
query = self.data.CounterEntry.table.select_where(counterid=row.counterid)
|
||||
query.select('userid', user_total="SUM(value)")
|
||||
query.group_by('userid')
|
||||
query.order_by('user_total', ORDER.DESC)
|
||||
if start_time is not None:
|
||||
query.where(self.data.CounterEntry.created_at >= start_time)
|
||||
query.with_no_adapter()
|
||||
results = await query
|
||||
lb = {result['userid']: result['user_total'] for result in results}
|
||||
|
||||
return lb
|
||||
|
||||
async def personal_total(self, counter: str, userid: int):
|
||||
row = await self.fetch_counter(counter)
|
||||
query = self.data.CounterEntry.table.select_where(counterid=row.counterid, userid=userid)
|
||||
query.select(user_total="SUM(value)")
|
||||
query.with_no_adapter()
|
||||
results = await query
|
||||
return results[0]['user_total'] if results else 0
|
||||
|
||||
async def totals(self, counter):
|
||||
row = await self.fetch_counter(counter)
|
||||
query = self.data.CounterEntry.table.select_where(counterid=row.counterid)
|
||||
query.select(counter_total="SUM(value)")
|
||||
query.with_no_adapter()
|
||||
results = await query
|
||||
return results[0]['counter_total'] if results else 0
|
||||
|
||||
# Counters commands
|
||||
@commands.command()
|
||||
async def counter(self, ctx: commands.Context, name: str, subcmd: Optional[str], *, args: Optional[str]=None):
|
||||
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||
return
|
||||
|
||||
name = name.lower()
|
||||
|
||||
if subcmd is None or subcmd == 'show':
|
||||
# Show
|
||||
total = await self.totals(name)
|
||||
await ctx.reply(f"'{name}' counter is: {total}")
|
||||
elif subcmd == 'add':
|
||||
if args is None:
|
||||
value = 1
|
||||
else:
|
||||
try:
|
||||
value = int(args)
|
||||
except ValueError:
|
||||
await ctx.reply(f"Could not parse value to add.")
|
||||
return
|
||||
await self.add_to_counter(
|
||||
name,
|
||||
int(ctx.author.id),
|
||||
value,
|
||||
context='cmd: counter add'
|
||||
)
|
||||
total = await self.totals(name)
|
||||
await ctx.reply(f"'{name}' counter is now: {total}")
|
||||
elif subcmd == 'lb':
|
||||
user = await ctx.channel.user()
|
||||
lbstr = await self.formatted_lb(name, args or '', int(user.id))
|
||||
await ctx.reply(lbstr)
|
||||
elif subcmd == 'clear':
|
||||
await self.reset_counter(name)
|
||||
await ctx.reply(f"'{name}' counter reset.")
|
||||
else:
|
||||
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear'.")
|
||||
|
||||
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
|
||||
if periodstr:
|
||||
period = next((period for period in PERIOD if periodstr.lower() in period.value), None)
|
||||
if period is None:
|
||||
raise ValueError("Invalid period string provided")
|
||||
else:
|
||||
period = default
|
||||
|
||||
now = utc_now()
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if period is PERIOD.ALL:
|
||||
start_time = None
|
||||
elif period is PERIOD.STREAM:
|
||||
streams = await self.bot.fetch_streams(user_ids=[userid])
|
||||
if streams:
|
||||
stream = streams[0]
|
||||
start_time = stream.started_at
|
||||
else:
|
||||
period = PERIOD.ALL
|
||||
start_time = None
|
||||
elif period is PERIOD.DAY:
|
||||
start_time = today
|
||||
elif period is PERIOD.WEEK:
|
||||
start_time = today - timedelta(days=today.weekday())
|
||||
elif period is PERIOD.MONTH:
|
||||
start_time = today.replace(day=1)
|
||||
elif period is PERIOD.YEAR:
|
||||
start_time = today.replace(day=1, month=1)
|
||||
else:
|
||||
period = PERIOD.ALL
|
||||
start_time = None
|
||||
|
||||
return (period, start_time)
|
||||
|
||||
async def formatted_lb(self, counter: str, periodstr: str, channelid: int):
|
||||
|
||||
period, start_time = await self.parse_period(channelid, periodstr)
|
||||
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
if lb:
|
||||
userids = list(lb.keys())
|
||||
users = await self.bot.fetch_users(ids=userids)
|
||||
name_map = {user.id: user.display_name for user in users}
|
||||
parts = []
|
||||
for userid, total in lb.items():
|
||||
name = name_map.get(userid, str(userid))
|
||||
part = f"{name}: {total}"
|
||||
parts.append(part)
|
||||
lbstr = '; '.join(parts)
|
||||
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}"
|
||||
else:
|
||||
return f"{counter} {period.value[-1]} leaderboard is empty!"
|
||||
|
||||
# Misc actual counter commands
|
||||
# TODO: Factor this out to a different module...
|
||||
@commands.command()
|
||||
async def tea(self, ctx: commands.Context, *, args: Optional[str]=None):
|
||||
userid = int(ctx.author.id)
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
period, start_time = await self.parse_period(channelid, '')
|
||||
counter = 'tea'
|
||||
|
||||
await self.add_to_counter(
|
||||
counter,
|
||||
userid,
|
||||
1,
|
||||
context='cmd: tea'
|
||||
)
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
user_total = lb.get(userid, 0)
|
||||
total = sum(lb.values())
|
||||
await ctx.reply(f"Enjoy your tea! We have had {total} cups of tea {period.value[0]}.")
|
||||
|
||||
@commands.command()
|
||||
async def tealb(self, ctx: commands.Context, *, args: str = ''):
|
||||
user = await ctx.channel.user()
|
||||
await ctx.reply(await self.formatted_lb('tea', args, int(user.id)))
|
||||
|
||||
@commands.command()
|
||||
async def coffee(self, ctx: commands.Context, *, args: Optional[str]=None):
|
||||
userid = int(ctx.author.id)
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
period, start_time = await self.parse_period(channelid, '')
|
||||
counter = 'coffee'
|
||||
|
||||
await self.add_to_counter(
|
||||
counter,
|
||||
userid,
|
||||
1,
|
||||
context='cmd: coffee'
|
||||
)
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
user_total = lb.get(userid, 0)
|
||||
total = sum(lb.values())
|
||||
await ctx.reply(f"Enjoy your coffee! We have had {total} cups of coffee {period.value[0]}.")
|
||||
|
||||
@commands.command()
|
||||
async def coffeelb(self, ctx: commands.Context, *, args: str = ''):
|
||||
user = await ctx.channel.user()
|
||||
await ctx.reply(await self.formatted_lb('coffee', args, int(user.id)))
|
||||
|
||||
@commands.command()
|
||||
async def water(self, ctx: commands.Context, *, args: Optional[str]=None):
|
||||
userid = int(ctx.author.id)
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
period, start_time = await self.parse_period(channelid, '')
|
||||
counter = 'water'
|
||||
|
||||
await self.add_to_counter(
|
||||
counter,
|
||||
userid,
|
||||
1,
|
||||
context='cmd: water'
|
||||
)
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
user_total = lb.get(userid, 0)
|
||||
total = sum(lb.values())
|
||||
await ctx.reply(f"Good job hydrating! We have had {total} cups of water {period.value[0]}.")
|
||||
|
||||
@commands.command()
|
||||
async def waterlb(self, ctx: commands.Context, *, args: str = ''):
|
||||
user = await ctx.channel.user()
|
||||
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
|
||||
|
||||
@commands.command()
|
||||
async def reload(self, ctx: commands.Context, *, args: str = ''):
|
||||
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||
return
|
||||
if not args:
|
||||
await ctx.reply("Full reload not implemented yet.")
|
||||
else:
|
||||
try:
|
||||
self.bot.reload_module(args)
|
||||
except Exception:
|
||||
logger.exception("Failed to reload")
|
||||
await ctx.reply("Failed to reload module! Check console~")
|
||||
else:
|
||||
await ctx.reply("Reloaded!")
|
||||
|
||||
48
src/modules/counters/data.py
Normal file
48
src/modules/counters/data.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from data import Registry, RowModel
|
||||
from data.columns import Integer, String, Timestamp
|
||||
|
||||
|
||||
class CounterData(Registry):
|
||||
class Counter(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE counters(
|
||||
counterid SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE UNIQUE INDEX counters_name ON counters (name);
|
||||
"""
|
||||
_tablename_ = 'counters'
|
||||
_cache_ = {}
|
||||
|
||||
counterid = Integer(primary=True)
|
||||
name = String()
|
||||
created_at = Timestamp()
|
||||
|
||||
class CounterEntry(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE counter_log(
|
||||
entryid SERIAL PRIMARY KEY,
|
||||
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
userid INTEGER NOT NULL,
|
||||
value INTEGER NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
context_str TEXT
|
||||
);
|
||||
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
||||
"""
|
||||
_tablename_ = 'counter_log'
|
||||
_cache_ = {}
|
||||
|
||||
entryid = Integer(primary=True)
|
||||
counterid = Integer()
|
||||
userid = Integer()
|
||||
value = Integer()
|
||||
created_at = Timestamp()
|
||||
context_str = String()
|
||||
|
||||
|
||||
23
src/modules/counters/lion_cog.py
Normal file
23
src/modules/counters/lion_cog.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
from meta import LionBot, LionCog, LionContext
|
||||
from meta.errors import UserInputError
|
||||
from meta.logger import log_wrap
|
||||
from utils.lib import utc_now
|
||||
from data.conditions import NULL
|
||||
|
||||
from . import logger
|
||||
from .data import CounterData
|
||||
|
||||
|
||||
class CounterCog(LionCog):
|
||||
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
|
||||
self.counter_cog = bot.crocbot.get_cog('CounterCog')
|
||||
Reference in New Issue
Block a user