Initial merger with Twitch interface.
This commit is contained in:
Binary file not shown.
BIN
assets/pomodoro/break_alert_orig.wav
Normal file
BIN
assets/pomodoro/break_alert_orig.wav
Normal file
Binary file not shown.
BIN
assets/pomodoro/chime.mp3
Normal file
BIN
assets/pomodoro/chime.mp3
Normal file
Binary file not shown.
Binary file not shown.
BIN
assets/pomodoro/focus_alert_orig.wav
Normal file
BIN
assets/pomodoro/focus_alert_orig.wav
Normal file
Binary file not shown.
@@ -1406,6 +1406,65 @@ CREATE TABLE stream_alerts(
|
||||
-- }}}
|
||||
|
||||
|
||||
-- Nowlist {{{
|
||||
|
||||
CREATE TABLE nowlist_tasks(
|
||||
userid BIGINT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
task TEXT NOT NULL,
|
||||
started_at TIMESTAMPTZ NOT NULL,
|
||||
done_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
-- }}}
|
||||
|
||||
-- Shoutouts {{{
|
||||
|
||||
CREATE TABLE shoutouts(
|
||||
userid BIGINT PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- }}}
|
||||
|
||||
-- Counters {{{
|
||||
|
||||
CREATE TABLE counters(
|
||||
counterid SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE UNIQUE INDEX counters_name ON counters (name);
|
||||
|
||||
CREATE TABLE counter_log(
|
||||
entryid SERIAL PRIMARY KEY,
|
||||
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
userid INTEGER NOT NULL,
|
||||
value INTEGER NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
context_str TEXT
|
||||
);
|
||||
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
||||
-- }}}
|
||||
|
||||
-- Tags {{{
|
||||
|
||||
CREATE TABLE channel_tags(
|
||||
tagid SERIAL PRIMARY KEY,
|
||||
channelid BIGINT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_by BIGINT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ
|
||||
);
|
||||
CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name);
|
||||
|
||||
-- }}}
|
||||
|
||||
|
||||
|
||||
-- Analytics Data {{{
|
||||
CREATE SCHEMA "analytics";
|
||||
|
||||
|
||||
79
src/bot.py
79
src/bot.py
@@ -1,15 +1,17 @@
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
import logging
|
||||
import websockets
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from meta import LionBot, conf, sharding, appname, shard_talk
|
||||
from meta import CrocBot, LionBot, conf, sharding, appname, shard_talk, sockets, args
|
||||
from meta.app import shardname
|
||||
from meta.logger import log_context, log_action_stack, setup_main_logger
|
||||
from meta.context import ctx_bot
|
||||
from meta.monitor import ComponentMonitor, StatusLevel, ComponentStatus
|
||||
from meta.monitor import ComponentMonitor, StatusLevel, ComponentStatus, SystemMonitor
|
||||
|
||||
from data import Database
|
||||
|
||||
@@ -58,18 +60,28 @@ async def main():
|
||||
intents.message_content = True
|
||||
intents.presences = False
|
||||
|
||||
async with db.open():
|
||||
system_monitor = SystemMonitor()
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
await stack.enter_async_context(db.open())
|
||||
|
||||
version = await db.version()
|
||||
if version.version != DATA_VERSION:
|
||||
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
|
||||
logger.critical(error)
|
||||
raise RuntimeError(error)
|
||||
system_monitor.add_component(ComponentMonitor('Database', _data_monitor))
|
||||
|
||||
translator = LeoBabel()
|
||||
ctx_translator.set(translator)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with LionBot(
|
||||
session = await stack.enter_async_context(aiohttp.ClientSession())
|
||||
await stack.enter_async_context(
|
||||
websockets.serve(sockets.root_handler, '', conf.wserver['port'])
|
||||
)
|
||||
|
||||
lionbot = await stack.enter_async_context(
|
||||
LionBot(
|
||||
command_prefix='!',
|
||||
intents=intents,
|
||||
appname=appname,
|
||||
@@ -81,7 +93,7 @@ async def main():
|
||||
'modules',
|
||||
'babel',
|
||||
'tracking.voice', 'tracking.text',
|
||||
],
|
||||
],
|
||||
web_client=session,
|
||||
app_ipc=shard_talk,
|
||||
testing_guilds=conf.bot.getintlist('admin_guilds'),
|
||||
@@ -91,18 +103,49 @@ async def main():
|
||||
proxy=conf.bot.get('proxy', None),
|
||||
translator=translator,
|
||||
chunk_guilds_at_startup=False,
|
||||
) as lionbot:
|
||||
ctx_bot.set(lionbot)
|
||||
lionbot.system_monitor.add_component(
|
||||
ComponentMonitor('Database', _data_monitor)
|
||||
)
|
||||
try:
|
||||
log_context.set(f"APP: {appname}")
|
||||
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
|
||||
await lionbot.start(conf.bot['TOKEN'])
|
||||
except asyncio.CancelledError:
|
||||
log_context.set(f"APP: {appname}")
|
||||
logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
|
||||
system_monitor=system_monitor,
|
||||
)
|
||||
)
|
||||
|
||||
crocbot = CrocBot(
|
||||
config=conf,
|
||||
data=db,
|
||||
prefix='!',
|
||||
initial_channels=conf.croccy.getlist('initial_channels'),
|
||||
token=conf.croccy['token'],
|
||||
lionbot=lionbot
|
||||
)
|
||||
lionbot.crocbot = crocbot
|
||||
|
||||
crocbot.load_module('modules')
|
||||
|
||||
crocstart = asyncio.create_task(start_croccy(crocbot))
|
||||
lionstart = asyncio.create_task(start_lion(lionbot))
|
||||
await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED)
|
||||
crocstart.cancel()
|
||||
lionstart.cancel()
|
||||
|
||||
async def start_lion(lionbot):
|
||||
ctx_bot.set(lionbot)
|
||||
try:
|
||||
log_context.set(f"APP: {appname}")
|
||||
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
|
||||
await lionbot.start(conf.bot['TOKEN'])
|
||||
except asyncio.CancelledError:
|
||||
log_context.set(f"APP: {appname}")
|
||||
logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
|
||||
|
||||
async def start_croccy(crocbot):
|
||||
try:
|
||||
log_context.set(f"APP: {appname}-croccy")
|
||||
logger.info("Starting Twitch bot.", extra={'action': 'Starting'})
|
||||
await crocbot.start()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Croccybot shutting down gracefully.")
|
||||
except Exception:
|
||||
logger.exception("Croccybot shutting down ungracefully.")
|
||||
finally:
|
||||
await crocbot.close()
|
||||
|
||||
|
||||
def _main():
|
||||
|
||||
32
src/meta/CrocBot.py
Normal file
32
src/meta/CrocBot.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import logging
|
||||
|
||||
from twitchio.ext import commands
|
||||
from twitchio.ext import pubsub
|
||||
|
||||
from data import Database
|
||||
|
||||
from .config import Conf
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .LionBot import LionBot
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CrocBot(commands.Bot):
|
||||
def __init__(self, *args,
|
||||
config: Conf,
|
||||
data: Database,
|
||||
lionbot: 'LionBot', **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.config = config
|
||||
self.data = data
|
||||
self.pubsub = pubsub.PubSubPool(self)
|
||||
self.lionbot = lionbot
|
||||
|
||||
async def event_ready(self):
|
||||
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
||||
@@ -56,7 +56,9 @@ class LionBot(Bot):
|
||||
def __init__(
|
||||
self, *args, appname: str, shardname: str, db: Database, config: Conf, translator: LeoBabel,
|
||||
initial_extensions: List[str], web_client: ClientSession, app_ipc,
|
||||
testing_guilds: List[int] = [], **kwargs
|
||||
testing_guilds: List[int] = [],
|
||||
system_monitor: Optional[SystemMonitor] = None,
|
||||
**kwargs
|
||||
):
|
||||
kwargs.setdefault('tree_cls', LionTree)
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -71,7 +73,7 @@ class LionBot(Bot):
|
||||
self.app_ipc = app_ipc
|
||||
self.translator = translator
|
||||
|
||||
self.system_monitor = SystemMonitor()
|
||||
self.system_monitor = system_monitor or SystemMonitor()
|
||||
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
|
||||
self.system_monitor.add_component(self.monitor)
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ from .LionCog import LionCog
|
||||
from .LionContext import LionContext
|
||||
from .LionTree import LionTree
|
||||
|
||||
from .CrocBot import CrocBot
|
||||
|
||||
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
|
||||
from .config import conf, configEmoji
|
||||
from .args import args
|
||||
@@ -10,6 +12,7 @@ from .app import appname, shard_talk, appname_from_shard, shard_from_appname
|
||||
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
|
||||
from .context import context, ctx_bot
|
||||
|
||||
from . import sockets
|
||||
from . import sharding
|
||||
from . import logger
|
||||
from . import app
|
||||
|
||||
68
src/meta/sockets.py
Normal file
68
src/meta/sockets.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from abc import ABC
|
||||
from collections import defaultdict
|
||||
import json
|
||||
from typing import Any
|
||||
import logging
|
||||
|
||||
import websockets
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class Channel(ABC):
|
||||
"""
|
||||
A channel is a stateful connection handler for a group of connected websockets.
|
||||
"""
|
||||
name = "Root Channel"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.connections = set()
|
||||
|
||||
@property
|
||||
def empty(self):
|
||||
return not self.connections
|
||||
|
||||
async def on_connection(self, websocket: websockets.WebSocketServerProtocol, event: dict[str, Any]):
|
||||
logger.info(f"Channel '{self.name}' attached new connection {websocket=} {event=}")
|
||||
self.connections.add(websocket)
|
||||
|
||||
async def del_connection(self, websocket: websockets.WebSocketServerProtocol):
|
||||
logger.info(f"Channel '{self.name}' dropped connection {websocket=}")
|
||||
self.connections.discard(websocket)
|
||||
|
||||
async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message):
|
||||
raise NotImplementedError
|
||||
|
||||
async def send_event(self, event, websocket=None):
|
||||
message = json.dumps(event)
|
||||
if not websocket:
|
||||
for ws in self.connections:
|
||||
await ws.send(message)
|
||||
else:
|
||||
await websocket.send(message)
|
||||
|
||||
channels = {}
|
||||
|
||||
def register_channel(name, channel: Channel):
|
||||
channels[name] = channel
|
||||
|
||||
|
||||
async def root_handler(websocket: websockets.WebSocketServerProtocol):
|
||||
message = await websocket.recv()
|
||||
event = json.loads(message)
|
||||
|
||||
if event.get('type', None) != 'init':
|
||||
raise ValueError("Received Websocket connection with no init.")
|
||||
|
||||
if (channel_name := event.get('channel', None)) not in channels:
|
||||
raise ValueError(f"Received Init for unhandled channel {channel_name=}")
|
||||
channel = channels[channel_name]
|
||||
|
||||
try:
|
||||
await channel.on_connection(websocket, event)
|
||||
async for message in websocket:
|
||||
await channel.handle_message(websocket, message)
|
||||
finally:
|
||||
await channel.del_connection(websocket)
|
||||
@@ -1,6 +1,6 @@
|
||||
this_package = 'modules'
|
||||
|
||||
active = [
|
||||
active_discord = [
|
||||
'.sysadmin',
|
||||
'.config',
|
||||
'.user_config',
|
||||
@@ -28,7 +28,18 @@ active = [
|
||||
'.test',
|
||||
]
|
||||
|
||||
active_twitch = [
|
||||
'.nowdoing',
|
||||
'.shoutouts',
|
||||
'.counters',
|
||||
'.tagstrings',
|
||||
]
|
||||
|
||||
|
||||
def prepare(bot):
|
||||
for ext in active_twitch:
|
||||
bot.load_module(this_package + ext)
|
||||
|
||||
async def setup(bot):
|
||||
for ext in active:
|
||||
for ext in active_discord:
|
||||
await bot.load_extension(ext, package=this_package)
|
||||
|
||||
13
src/modules/counters/__init__.py
Normal file
13
src/modules/counters/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import CounterCog
|
||||
|
||||
def prepare(bot):
|
||||
bot.add_cog(CounterCog(bot))
|
||||
|
||||
async def setup(bot):
|
||||
from .lion_cog import CounterCog
|
||||
|
||||
await bot.add_cog(CounterCog(bot))
|
||||
299
src/modules/counters/cog.py
Normal file
299
src/modules/counters/cog.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from datetime import timedelta
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
|
||||
from data.queries import ORDER
|
||||
from meta import CrocBot
|
||||
from utils.lib import utc_now
|
||||
from . import logger
|
||||
from .data import CounterData
|
||||
|
||||
|
||||
class PERIOD(Enum):
|
||||
ALL = ('', 'all', 'all-time')
|
||||
STREAM = ('this stream', 'stream',)
|
||||
DAY = ('today', 'd', 'day', 'today', 'daily')
|
||||
WEEK = ('this week', 'w', 'week', 'weekly')
|
||||
MONTH = ('this month', 'm', 'mo', 'month', 'monthly')
|
||||
YEAR = ('this year', 'y', 'year', 'yearly')
|
||||
|
||||
|
||||
class CounterCog(commands.Cog):
|
||||
def __init__(self, bot: CrocBot):
|
||||
self.bot = bot
|
||||
self.data = bot.data.load_registry(CounterData())
|
||||
|
||||
self.loaded = asyncio.Event()
|
||||
|
||||
# Cache of counter names -> rows
|
||||
self.counters = {}
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
self.loaded.set()
|
||||
|
||||
async def load_counters(self):
|
||||
"""
|
||||
Initialise counter name cache.
|
||||
"""
|
||||
rows = await self.data.Counter.fetch_where()
|
||||
self.counters = {row.name: row for row in rows}
|
||||
logger.info(
|
||||
f"Loaded {len(self.counters)} counters."
|
||||
)
|
||||
|
||||
async def ensure_loaded(self):
|
||||
if not self.loaded.is_set():
|
||||
await self.cog_load()
|
||||
|
||||
@commands.Cog.event('event_ready') # type: ignore
|
||||
async def on_ready(self):
|
||||
await self.ensure_loaded()
|
||||
|
||||
async def cog_check(self, ctx):
|
||||
await self.ensure_loaded()
|
||||
return True
|
||||
|
||||
# Counters API
|
||||
|
||||
async def fetch_counter(self, counter: str) -> CounterData.Counter:
|
||||
"""
|
||||
Fetches the Counter with the given name,
|
||||
or creates it if it doesn't exist.
|
||||
"""
|
||||
if (row := self.counters.get(counter, None)) is None:
|
||||
row = await self.data.Counter.fetch_or_create(name=counter)
|
||||
self.counters[counter] = row
|
||||
return row
|
||||
|
||||
async def delete_counter(self, counter: str):
|
||||
self.counters.pop(counter, None)
|
||||
await self.data.Counter.table.delete_where(name=counter)
|
||||
|
||||
async def reset_counter(self, counter: str):
|
||||
row = self.counters.get(counter, None)
|
||||
if row:
|
||||
await self.data.CounterEntry.table.delete_where(counterid=row.counterid)
|
||||
|
||||
async def add_to_counter(self, counter: str, userid: int, value: int, context: Optional[str]=None):
|
||||
row = await self.fetch_counter(counter)
|
||||
return await self.data.CounterEntry.create(
|
||||
counterid=row.counterid,
|
||||
userid=userid,
|
||||
value=value,
|
||||
context_str=context
|
||||
)
|
||||
|
||||
async def leaderboard(self, counter: str, start_time=None):
|
||||
row = await self.fetch_counter(counter)
|
||||
query = self.data.CounterEntry.table.select_where(counterid=row.counterid)
|
||||
query.select('userid', user_total="SUM(value)")
|
||||
query.group_by('userid')
|
||||
query.order_by('user_total', ORDER.DESC)
|
||||
if start_time is not None:
|
||||
query.where(self.data.CounterEntry.created_at >= start_time)
|
||||
query.with_no_adapter()
|
||||
results = await query
|
||||
lb = {result['userid']: result['user_total'] for result in results}
|
||||
|
||||
return lb
|
||||
|
||||
async def personal_total(self, counter: str, userid: int):
|
||||
row = await self.fetch_counter(counter)
|
||||
query = self.data.CounterEntry.table.select_where(counterid=row.counterid, userid=userid)
|
||||
query.select(user_total="SUM(value)")
|
||||
query.with_no_adapter()
|
||||
results = await query
|
||||
return results[0]['user_total'] if results else 0
|
||||
|
||||
async def totals(self, counter):
|
||||
row = await self.fetch_counter(counter)
|
||||
query = self.data.CounterEntry.table.select_where(counterid=row.counterid)
|
||||
query.select(counter_total="SUM(value)")
|
||||
query.with_no_adapter()
|
||||
results = await query
|
||||
return results[0]['counter_total'] if results else 0
|
||||
|
||||
# Counters commands
|
||||
@commands.command()
|
||||
async def counter(self, ctx: commands.Context, name: str, subcmd: Optional[str], *, args: Optional[str]=None):
|
||||
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||
return
|
||||
|
||||
name = name.lower()
|
||||
|
||||
if subcmd is None or subcmd == 'show':
|
||||
# Show
|
||||
total = await self.totals(name)
|
||||
await ctx.reply(f"'{name}' counter is: {total}")
|
||||
elif subcmd == 'add':
|
||||
if args is None:
|
||||
value = 1
|
||||
else:
|
||||
try:
|
||||
value = int(args)
|
||||
except ValueError:
|
||||
await ctx.reply(f"Could not parse value to add.")
|
||||
return
|
||||
await self.add_to_counter(
|
||||
name,
|
||||
int(ctx.author.id),
|
||||
value,
|
||||
context='cmd: counter add'
|
||||
)
|
||||
total = await self.totals(name)
|
||||
await ctx.reply(f"'{name}' counter is now: {total}")
|
||||
elif subcmd == 'lb':
|
||||
user = await ctx.channel.user()
|
||||
lbstr = await self.formatted_lb(name, args or '', int(user.id))
|
||||
await ctx.reply(lbstr)
|
||||
elif subcmd == 'clear':
|
||||
await self.reset_counter(name)
|
||||
await ctx.reply(f"'{name}' counter reset.")
|
||||
else:
|
||||
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear'.")
|
||||
|
||||
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
|
||||
if periodstr:
|
||||
period = next((period for period in PERIOD if periodstr.lower() in period.value), None)
|
||||
if period is None:
|
||||
raise ValueError("Invalid period string provided")
|
||||
else:
|
||||
period = default
|
||||
|
||||
now = utc_now()
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if period is PERIOD.ALL:
|
||||
start_time = None
|
||||
elif period is PERIOD.STREAM:
|
||||
streams = await self.bot.fetch_streams(user_ids=[userid])
|
||||
if streams:
|
||||
stream = streams[0]
|
||||
start_time = stream.started_at
|
||||
else:
|
||||
period = PERIOD.ALL
|
||||
start_time = None
|
||||
elif period is PERIOD.DAY:
|
||||
start_time = today
|
||||
elif period is PERIOD.WEEK:
|
||||
start_time = today - timedelta(days=today.weekday())
|
||||
elif period is PERIOD.MONTH:
|
||||
start_time = today.replace(day=1)
|
||||
elif period is PERIOD.YEAR:
|
||||
start_time = today.replace(day=1, month=1)
|
||||
else:
|
||||
period = PERIOD.ALL
|
||||
start_time = None
|
||||
|
||||
return (period, start_time)
|
||||
|
||||
async def formatted_lb(self, counter: str, periodstr: str, channelid: int):
|
||||
|
||||
period, start_time = await self.parse_period(channelid, periodstr)
|
||||
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
if lb:
|
||||
userids = list(lb.keys())
|
||||
users = await self.bot.fetch_users(ids=userids)
|
||||
name_map = {user.id: user.display_name for user in users}
|
||||
parts = []
|
||||
for userid, total in lb.items():
|
||||
name = name_map.get(userid, str(userid))
|
||||
part = f"{name}: {total}"
|
||||
parts.append(part)
|
||||
lbstr = '; '.join(parts)
|
||||
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}"
|
||||
else:
|
||||
return f"{counter} {period.value[-1]} leaderboard is empty!"
|
||||
|
||||
# Misc actual counter commands
|
||||
# TODO: Factor this out to a different module...
|
||||
@commands.command()
|
||||
async def tea(self, ctx: commands.Context, *, args: Optional[str]=None):
|
||||
userid = int(ctx.author.id)
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
period, start_time = await self.parse_period(channelid, '')
|
||||
counter = 'tea'
|
||||
|
||||
await self.add_to_counter(
|
||||
counter,
|
||||
userid,
|
||||
1,
|
||||
context='cmd: tea'
|
||||
)
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
user_total = lb.get(userid, 0)
|
||||
total = sum(lb.values())
|
||||
await ctx.reply(f"Enjoy your tea! We have had {total} cups of tea {period.value[0]}.")
|
||||
|
||||
@commands.command()
|
||||
async def tealb(self, ctx: commands.Context, *, args: str = ''):
|
||||
user = await ctx.channel.user()
|
||||
await ctx.reply(await self.formatted_lb('tea', args, int(user.id)))
|
||||
|
||||
@commands.command()
|
||||
async def coffee(self, ctx: commands.Context, *, args: Optional[str]=None):
|
||||
userid = int(ctx.author.id)
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
period, start_time = await self.parse_period(channelid, '')
|
||||
counter = 'coffee'
|
||||
|
||||
await self.add_to_counter(
|
||||
counter,
|
||||
userid,
|
||||
1,
|
||||
context='cmd: coffee'
|
||||
)
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
user_total = lb.get(userid, 0)
|
||||
total = sum(lb.values())
|
||||
await ctx.reply(f"Enjoy your coffee! We have had {total} cups of coffee {period.value[0]}.")
|
||||
|
||||
@commands.command()
|
||||
async def coffeelb(self, ctx: commands.Context, *, args: str = ''):
|
||||
user = await ctx.channel.user()
|
||||
await ctx.reply(await self.formatted_lb('coffee', args, int(user.id)))
|
||||
|
||||
@commands.command()
|
||||
async def water(self, ctx: commands.Context, *, args: Optional[str]=None):
|
||||
userid = int(ctx.author.id)
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
period, start_time = await self.parse_period(channelid, '')
|
||||
counter = 'water'
|
||||
|
||||
await self.add_to_counter(
|
||||
counter,
|
||||
userid,
|
||||
1,
|
||||
context='cmd: water'
|
||||
)
|
||||
lb = await self.leaderboard(counter, start_time=start_time)
|
||||
user_total = lb.get(userid, 0)
|
||||
total = sum(lb.values())
|
||||
await ctx.reply(f"Good job hydrating! We have had {total} cups of water {period.value[0]}.")
|
||||
|
||||
@commands.command()
|
||||
async def waterlb(self, ctx: commands.Context, *, args: str = ''):
|
||||
user = await ctx.channel.user()
|
||||
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
|
||||
|
||||
@commands.command()
|
||||
async def reload(self, ctx: commands.Context, *, args: str = ''):
|
||||
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||
return
|
||||
if not args:
|
||||
await ctx.reply("Full reload not implemented yet.")
|
||||
else:
|
||||
try:
|
||||
self.bot.reload_module(args)
|
||||
except Exception:
|
||||
logger.exception("Failed to reload")
|
||||
await ctx.reply("Failed to reload module! Check console~")
|
||||
else:
|
||||
await ctx.reply("Reloaded!")
|
||||
|
||||
48
src/modules/counters/data.py
Normal file
48
src/modules/counters/data.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from data import Registry, RowModel
|
||||
from data.columns import Integer, String, Timestamp
|
||||
|
||||
|
||||
class CounterData(Registry):
|
||||
class Counter(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE counters(
|
||||
counterid SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE UNIQUE INDEX counters_name ON counters (name);
|
||||
"""
|
||||
_tablename_ = 'counters'
|
||||
_cache_ = {}
|
||||
|
||||
counterid = Integer(primary=True)
|
||||
name = String()
|
||||
created_at = Timestamp()
|
||||
|
||||
class CounterEntry(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE counter_log(
|
||||
entryid SERIAL PRIMARY KEY,
|
||||
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
userid INTEGER NOT NULL,
|
||||
value INTEGER NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
context_str TEXT
|
||||
);
|
||||
CREATE INDEX counter_log_counterid ON counter_log (counterid);
|
||||
"""
|
||||
_tablename_ = 'counter_log'
|
||||
_cache_ = {}
|
||||
|
||||
entryid = Integer(primary=True)
|
||||
counterid = Integer()
|
||||
userid = Integer()
|
||||
value = Integer()
|
||||
created_at = Timestamp()
|
||||
context_str = String()
|
||||
|
||||
|
||||
23
src/modules/counters/lion_cog.py
Normal file
23
src/modules/counters/lion_cog.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
from meta import LionBot, LionCog, LionContext
|
||||
from meta.errors import UserInputError
|
||||
from meta.logger import log_wrap
|
||||
from utils.lib import utc_now
|
||||
from data.conditions import NULL
|
||||
|
||||
from . import logger
|
||||
from .data import CounterData
|
||||
|
||||
|
||||
class CounterCog(LionCog):
|
||||
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
|
||||
self.counter_cog = bot.crocbot.get_cog('CounterCog')
|
||||
9
src/modules/nowdoing/__init__.py
Normal file
9
src/modules/nowdoing/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import NowDoingCog
|
||||
|
||||
def prepare(bot):
|
||||
logger.info("Preparing the nowdoing module.")
|
||||
bot.add_cog(NowDoingCog(bot))
|
||||
253
src/modules/nowdoing/cog.py
Normal file
253
src/modules/nowdoing/cog.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from attr import dataclass
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
|
||||
from meta import CrocBot
|
||||
from meta.sockets import Channel, register_channel
|
||||
from utils.lib import strfdelta, utc_now
|
||||
from . import logger
|
||||
from .data import NowListData
|
||||
|
||||
|
||||
class NowDoingChannel(Channel):
|
||||
name = 'NowList'
|
||||
|
||||
def __init__(self, cog: 'NowDoingCog', **kwargs):
|
||||
self.cog = cog
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def on_connection(self, websocket, event):
|
||||
await super().on_connection(websocket, event)
|
||||
for task in self.cog.tasks.values():
|
||||
await self.send_set(*self.task_args(task), websocket=websocket)
|
||||
|
||||
async def send_test_set(self):
|
||||
tasks = [
|
||||
(0, 'Tester0', "Testing Tasklist", True),
|
||||
(1, 'Tester1', "Getting Confused", False),
|
||||
(2, "Tester2", "Generating Bugs", True),
|
||||
(3, "Tester3", "Fixing Bugs", False),
|
||||
(4, "Tester4", "Pushing the red button", False),
|
||||
]
|
||||
for task in tasks:
|
||||
await self.send_set(*task)
|
||||
|
||||
def task_args(self, task: NowListData.Task):
|
||||
return (
|
||||
task.userid,
|
||||
task.name,
|
||||
task.task,
|
||||
task.started_at.isoformat(),
|
||||
task.done_at.isoformat() if task.done_at else None,
|
||||
)
|
||||
|
||||
async def send_set(self, userid, name, task, start_at, end_at, websocket=None):
|
||||
await self.send_event({
|
||||
'type': "DO",
|
||||
'method': "setTask",
|
||||
'args': {
|
||||
'userid': userid,
|
||||
'name': name,
|
||||
'task': task,
|
||||
'start_at': start_at,
|
||||
'end_at': end_at,
|
||||
}
|
||||
}, websocket=websocket)
|
||||
|
||||
async def send_del(self, userid):
|
||||
await self.send_event({
|
||||
'type': "DO",
|
||||
'method': "delTask",
|
||||
'args': {
|
||||
'userid': userid,
|
||||
}
|
||||
})
|
||||
|
||||
async def send_clear(self):
|
||||
await self.send_event({
|
||||
'type': "DO",
|
||||
'method': "clearTasks",
|
||||
'args': {
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
class NowDoingCog(commands.Cog):
|
||||
def __init__(self, bot: CrocBot):
|
||||
self.bot = bot
|
||||
self.data = bot.data.load_registry(NowListData())
|
||||
self.channel = NowDoingChannel(self)
|
||||
register_channel(self.channel.name, self.channel)
|
||||
|
||||
# userid -> Task
|
||||
self.tasks: dict[int, NowListData.Task] = {}
|
||||
|
||||
self.loaded = asyncio.Event()
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
await self.load_tasks()
|
||||
self.loaded.set()
|
||||
|
||||
async def ensure_loaded(self):
|
||||
"""
|
||||
Hack because lib devs decided to remove async cog loading.
|
||||
"""
|
||||
if not self.loaded.is_set():
|
||||
await self.cog_load()
|
||||
|
||||
@commands.Cog.event('event_ready') # type: ignore
|
||||
async def on_ready(self):
|
||||
await self.ensure_loaded()
|
||||
|
||||
async def cog_check(self, ctx):
|
||||
await self.ensure_loaded()
|
||||
return True
|
||||
|
||||
async def load_tasks(self):
|
||||
tasklist = await self.data.Task.fetch_where()
|
||||
tasks = {task.userid: task for task in tasklist}
|
||||
self.tasks = tasks
|
||||
logger.info(f"Loaded {len(tasks)} from database.")
|
||||
|
||||
@commands.command()
|
||||
async def test(self, ctx: commands.Context):
|
||||
if (ctx.author.is_broadcaster):
|
||||
# await self.channel.send_test_set()
|
||||
# await ctx.send(f"Hello {ctx.author.name}! This command does something, we aren't sure what yet.")
|
||||
# await ctx.send(str(list(self.tasks.items())[0]))
|
||||
await ctx.send(str(ctx.author.id))
|
||||
else:
|
||||
await ctx.send(f"Hello {ctx.author.name}! I don't think you have permission to test that.")
|
||||
|
||||
@commands.command(aliases=['task', 'check'])
|
||||
async def now(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||
userid = int(ctx.author.id)
|
||||
if args:
|
||||
await self.data.Task.table.delete_where(userid=userid)
|
||||
task = await self.data.Task.create(
|
||||
userid=userid,
|
||||
name=ctx.author.display_name,
|
||||
task=args,
|
||||
started_at=utc_now(),
|
||||
)
|
||||
self.tasks[task.userid] = task
|
||||
await self.channel.send_set(*self.channel.task_args(task))
|
||||
await ctx.send(f"Updated your current task, good luck!")
|
||||
elif task := self.tasks.get(userid, None):
|
||||
if task.is_done:
|
||||
done_ago = strfdelta(utc_now() - task.done_at)
|
||||
await ctx.send(
|
||||
f"You finished '{task.task}' {done_ago} ago!"
|
||||
)
|
||||
else:
|
||||
started_ago = strfdelta(utc_now() - task.started_at)
|
||||
await ctx.send(
|
||||
f"You have been working on '{task.task}' for {started_ago}!"
|
||||
)
|
||||
else:
|
||||
await ctx.send(
|
||||
"You don't have a task on the tasklist! "
|
||||
"Show what you are currently working on with, e.g. !now Reading notes"
|
||||
)
|
||||
|
||||
@commands.command(name='next')
|
||||
async def nownext(self, ctx: commands.Context, *, args: Optional[str] = None):
|
||||
userid = int(ctx.author.id)
|
||||
task = self.tasks.get(userid, None)
|
||||
if args:
|
||||
if task:
|
||||
if not task.is_done:
|
||||
await task.update(done_at=utc_now())
|
||||
started_ago = strfdelta(task.done_at - task.started_at)
|
||||
prefix = (
|
||||
f"You worked on '{task.task}' for {started_ago}."
|
||||
)
|
||||
else:
|
||||
prefix = ""
|
||||
await self.data.Task.table.delete_where(userid=userid)
|
||||
task = await self.data.Task.create(
|
||||
userid=userid,
|
||||
name=ctx.author.display_name,
|
||||
task=args,
|
||||
started_at=utc_now(),
|
||||
)
|
||||
self.tasks[task.userid] = task
|
||||
await self.channel.send_set(*self.channel.task_args(task))
|
||||
await ctx.send(f"Next task set, good luck!" + ' ' + prefix)
|
||||
elif task:
|
||||
if task.is_done:
|
||||
done_ago = strfdelta(utc_now() - task.done_at)
|
||||
await ctx.send(
|
||||
f"You finished '{task.task}' {done_ago} ago!"
|
||||
)
|
||||
else:
|
||||
started_ago = strfdelta(utc_now() - task.started_at)
|
||||
await ctx.send(
|
||||
f"You have been working on '{task.task}' for {started_ago}!"
|
||||
)
|
||||
else:
|
||||
await ctx.send(
|
||||
"You don't have a task on the tasklist! "
|
||||
"Show what you are currently working on with, e.g. !now Reading notes"
|
||||
)
|
||||
|
||||
@commands.command()
|
||||
async def done(self, ctx: commands.Context):
|
||||
userid = int(ctx.author.id)
|
||||
if task := self.tasks.get(userid, None):
|
||||
if task.is_done:
|
||||
await ctx.send(
|
||||
f"You already finished '{task.task}'!"
|
||||
)
|
||||
else:
|
||||
await task.update(done_at=utc_now())
|
||||
started_ago = strfdelta(task.done_at - task.started_at)
|
||||
await self.channel.send_set(*self.channel.task_args(task))
|
||||
await ctx.send(
|
||||
f"Good job finishing '{task.task}'! "
|
||||
f"You worked on it for {started_ago}."
|
||||
)
|
||||
else:
|
||||
await ctx.send(
|
||||
"You don't have a task on the tasklist! "
|
||||
"Show what you are currently working on with, e.g. !now Reading notes"
|
||||
)
|
||||
|
||||
@commands.command()
|
||||
async def clear(self, ctx: commands.Context):
|
||||
userid = int(ctx.author.id)
|
||||
if task := self.tasks.pop(userid, None):
|
||||
await task.delete()
|
||||
await self.channel.send_del(userid)
|
||||
await ctx.send("Removed your task from the tasklist!")
|
||||
else:
|
||||
await ctx.send(
|
||||
"You don't have a task on the tasklist at the moment!"
|
||||
)
|
||||
|
||||
@commands.command()
|
||||
async def clearfor(self, ctx: commands.Context, user: twitchio.User):
|
||||
if (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||
await self.channel.send_del(int(user.id))
|
||||
task = self.tasks.pop(int(user.id), None)
|
||||
if task is not None:
|
||||
await task.delete()
|
||||
await ctx.send("Cleared the task.")
|
||||
else:
|
||||
pass
|
||||
|
||||
@commands.command()
|
||||
async def clearall(self, ctx: commands.Context):
|
||||
if (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||
await self.data.Task.table.delete_where()
|
||||
self.tasks.clear()
|
||||
await self.channel.send_clear()
|
||||
await ctx.send("Tasklist Cleared!")
|
||||
29
src/modules/nowdoing/data.py
Normal file
29
src/modules/nowdoing/data.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from data import Registry, RowModel
|
||||
from data.columns import Integer, Timestamp, String
|
||||
|
||||
|
||||
class NowListData(Registry):
|
||||
class Task(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE nowlist_tasks(
|
||||
userid BIGINT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
task TEXT NOT NULL,
|
||||
started_at TIMESTAMPTZ NOT NULL,
|
||||
done_at TIMESTAMPTZ
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'nowlist_tasks'
|
||||
_cache_ = {}
|
||||
|
||||
userid = Integer(primary=True)
|
||||
name = String()
|
||||
task = String()
|
||||
started_at = Timestamp()
|
||||
done_at = Timestamp()
|
||||
|
||||
@property
|
||||
def is_done(self):
|
||||
return self.done_at is not None
|
||||
@@ -13,6 +13,7 @@ from meta.sharding import THIS_SHARD
|
||||
from meta.monitor import ComponentMonitor, ComponentStatus, StatusLevel
|
||||
from utils.lib import utc_now
|
||||
from utils.ratelimits import limit_concurrency
|
||||
from meta.sockets import Channel, register_channel
|
||||
|
||||
from wards import low_management_ward
|
||||
|
||||
@@ -39,6 +40,37 @@ _param_options = {
|
||||
}
|
||||
|
||||
|
||||
class TimerChannel(Channel):
|
||||
name = 'Timer'
|
||||
|
||||
def __init__(self, cog: 'TimerCog', **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.cog = cog
|
||||
|
||||
async def on_connection(self, websocket, event):
|
||||
await super().on_connection(websocket, event)
|
||||
timer = self.cog.get_channel_timer(1261999440160624734)
|
||||
if timer is not None:
|
||||
await self.send_set(
|
||||
timer.data.last_started,
|
||||
timer.data.focus_length,
|
||||
timer.data.break_length,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
async def send_set(self, start_at, focus_length, break_length, goal=12, websocket=None):
|
||||
await self.send_event({
|
||||
'type': "DO",
|
||||
'method': 'setTimer',
|
||||
'args': {
|
||||
'start_at': start_at.isoformat(),
|
||||
'focus_length': focus_length,
|
||||
'break_length': break_length,
|
||||
'block_goal': goal,
|
||||
}
|
||||
}, websocket=websocket)
|
||||
|
||||
|
||||
class TimerCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
@@ -46,6 +78,9 @@ class TimerCog(LionCog):
|
||||
self.settings = TimerSettings()
|
||||
self.monitor = ComponentMonitor('TimerCog', self._monitor)
|
||||
|
||||
self.channel = TimerChannel(self)
|
||||
register_channel(self.channel.name, self.channel)
|
||||
|
||||
self.timer_options = TimerOptions()
|
||||
|
||||
self.ready = False
|
||||
@@ -1012,3 +1047,31 @@ class TimerCog(LionCog):
|
||||
ui = TimerConfigUI(self.bot, ctx.guild.id, ctx.channel.id)
|
||||
await ui.run(ctx.interaction)
|
||||
await ui.wait()
|
||||
|
||||
# ----- Hacky Stream commands -----
|
||||
@cmds.hybrid_group('streamtimer', with_app_command=True)
|
||||
async def streamtimer_group(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@streamtimer_group.command(
|
||||
name="update"
|
||||
)
|
||||
@low_management_ward
|
||||
async def streamtimer_update_cmd(self, ctx: LionContext,
|
||||
new_start: Optional[str] = None,
|
||||
new_goal: int = 12):
|
||||
timer = self.get_channel_timer(1261999440160624734)
|
||||
if timer is None:
|
||||
return
|
||||
if new_start:
|
||||
timezone = ctx.lmember.timezone
|
||||
start_at = await self.bot.get_cog('Reminders').parse_time_static(new_start, timezone)
|
||||
await timer.data.update(last_started=start_at)
|
||||
|
||||
await self.channel.send_set(
|
||||
timer.data.last_started,
|
||||
timer.data.focus_length,
|
||||
timer.data.break_length,
|
||||
goal=new_goal,
|
||||
)
|
||||
await ctx.reply("Stream Timer Updated")
|
||||
|
||||
@@ -195,9 +195,7 @@ class Timer:
|
||||
Uses voice channel member cache as source-of-truth.
|
||||
"""
|
||||
if (chan := self.channel):
|
||||
members = [
|
||||
member for member in chan.members if not member.bot and 1148167212901859328 in [role.id for role in member.roles]
|
||||
]
|
||||
members = [m for m in chan.members if not m.bot]
|
||||
else:
|
||||
members = []
|
||||
return members
|
||||
@@ -480,6 +478,7 @@ class Timer:
|
||||
if self.guild.voice_client:
|
||||
await self.guild.voice_client.disconnect(force=True)
|
||||
alert_file = focus_alert_path if stage.focused else break_alert_path
|
||||
|
||||
try:
|
||||
voice_client = await asyncio.wait_for(
|
||||
self.channel.connect(timeout=30, reconnect=False),
|
||||
@@ -613,7 +612,11 @@ class Timer:
|
||||
if render:
|
||||
try:
|
||||
card = await get_timer_card(self.bot, self, stage)
|
||||
await card.render()
|
||||
data = await card.render()
|
||||
import io
|
||||
with io.BytesIO(data) as buffer:
|
||||
with open(f"pomodoro_{self.data.channelid}.png", "wb") as f:
|
||||
f.write(buffer.getbuffer())
|
||||
rawargs['file'] = card.as_file(f"pomodoro_{self.data.channelid}.png")
|
||||
except RenderingException:
|
||||
pass
|
||||
@@ -841,8 +844,8 @@ class Timer:
|
||||
to_next_stage = (current.end - utc_now()).total_seconds()
|
||||
|
||||
# TODO: Consider request rate and load
|
||||
if to_next_stage > 5 * 60 - drift:
|
||||
time_to_sleep = 5 * 60
|
||||
if to_next_stage > 1 * 60 - drift:
|
||||
time_to_sleep = 1 * 60
|
||||
else:
|
||||
time_to_sleep = to_next_stage
|
||||
|
||||
|
||||
8
src/modules/shoutouts/__init__.py
Normal file
8
src/modules/shoutouts/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import ShoutoutCog
|
||||
|
||||
def prepare(bot):
|
||||
bot.add_cog(ShoutoutCog(bot))
|
||||
90
src/modules/shoutouts/cog.py
Normal file
90
src/modules/shoutouts/cog.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
|
||||
from meta import CrocBot
|
||||
from utils.lib import replace_multiple
|
||||
from . import logger
|
||||
from .data import ShoutoutData
|
||||
|
||||
|
||||
class ShoutoutCog(commands.Cog):
|
||||
# Future extension: channel defaults and config
|
||||
DEFAULT_SHOUTOUT = """
|
||||
We think that {name} is a great streamer and you should check them out \
|
||||
and drop a follow! \
|
||||
They {areorwere} streaming {game} at {channel}
|
||||
"""
|
||||
def __init__(self, bot: CrocBot):
|
||||
self.bot = bot
|
||||
self.data = bot.data.load_registry(ShoutoutData())
|
||||
|
||||
self.loaded = asyncio.Event()
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
self.loaded.set()
|
||||
|
||||
async def ensure_loaded(self):
|
||||
if not self.loaded.is_set():
|
||||
await self.cog_load()
|
||||
|
||||
@commands.Cog.event('event_ready') # type: ignore
|
||||
async def on_ready(self):
|
||||
await self.ensure_loaded()
|
||||
|
||||
async def cog_check(self, ctx):
|
||||
await self.ensure_loaded()
|
||||
return True
|
||||
|
||||
async def format_shoutout(self, text: str, user: twitchio.User):
|
||||
channels = await self.bot.fetch_channels([user.id])
|
||||
if channels:
|
||||
channel = channels[0]
|
||||
game = channel.game_name or 'Unknown'
|
||||
else:
|
||||
game = 'Unknown'
|
||||
|
||||
streams = await self.bot.fetch_streams([user.id])
|
||||
live = bool(streams)
|
||||
|
||||
mapping = {
|
||||
'{name}': user.display_name,
|
||||
'{channel}': f"https://www.twitch.tv/{user.name}",
|
||||
'{game}': game,
|
||||
'{areorwere}': 'are' if live else 'were',
|
||||
}
|
||||
return replace_multiple(text, mapping)
|
||||
|
||||
@commands.command(aliases=['so'])
|
||||
async def shoutout(self, ctx: commands.Context, user: twitchio.User):
|
||||
# Make sure caller is mod/broadcaster
|
||||
# Lookup custom shoutout for this user
|
||||
# If it exists use it, otherwise use default shoutout
|
||||
if (ctx.author.is_mod or ctx.author.is_broadcaster):
|
||||
data = await self.data.CustomShoutout.fetch(int(user.id))
|
||||
if data:
|
||||
shoutout = data.content
|
||||
else:
|
||||
shoutout = self.DEFAULT_SHOUTOUT
|
||||
formatted = await self.format_shoutout(shoutout, user)
|
||||
await ctx.reply(formatted)
|
||||
# TODO: How to /shoutout with lib?
|
||||
|
||||
@commands.command()
|
||||
async def editshoutout(self, ctx: commands.Context, user: twitchio.User, *, text: str):
|
||||
# Make sure caller is mod/broadcaster/user themselves(?)
|
||||
# upsert/delete and insert (is upsert impl?)
|
||||
if (ctx.author.is_mod or ctx.author.is_broadcaster or int(ctx.author.id) == int(user.id)):
|
||||
await self.data.CustomShoutout.table.delete_where(userid=int(user.id))
|
||||
|
||||
if text and text.lower() not in ('reset', 'none'):
|
||||
await self.data.CustomShoutout.create(
|
||||
userid=int(user.id),
|
||||
content=text,
|
||||
)
|
||||
await ctx.reply("Custom shoutout updated!")
|
||||
else:
|
||||
await ctx.reply("Custom shoutout removed.")
|
||||
21
src/modules/shoutouts/data.py
Normal file
21
src/modules/shoutouts/data.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from data import Registry, RowModel
|
||||
from data.columns import Integer, String, Timestamp
|
||||
|
||||
|
||||
class ShoutoutData(Registry):
|
||||
class CustomShoutout(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE shoutouts(
|
||||
userid BIGINT PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'shoutouts'
|
||||
_cache_ = {}
|
||||
|
||||
userid = Integer(primary=True)
|
||||
content = String()
|
||||
created_at = Timestamp()
|
||||
@@ -115,7 +115,7 @@ class AlertCog(LionCog):
|
||||
# Note we set page size to 100
|
||||
# So we should never get repeat or missed streams
|
||||
# Since we can request a max of 100 userids anyway.
|
||||
streaming[stream.user_id] = stream
|
||||
streaming[int(stream.user_id)] = stream
|
||||
|
||||
started = set(streaming.keys()).difference(self.live_streams.keys())
|
||||
ended = set(self.live_streams.keys()).difference(streaming.keys())
|
||||
@@ -123,9 +123,9 @@ class AlertCog(LionCog):
|
||||
for streamerid in started:
|
||||
stream = streaming[streamerid]
|
||||
stream_data = await self.data.Stream.create(
|
||||
streamerid=stream.user_id,
|
||||
streamerid=int(stream.user_id),
|
||||
start_at=stream.started_at,
|
||||
twitch_stream_id=stream.id,
|
||||
twitch_stream_id=int(stream.id),
|
||||
game_name=stream.game_name,
|
||||
title=stream.title,
|
||||
)
|
||||
@@ -143,7 +143,7 @@ class AlertCog(LionCog):
|
||||
|
||||
async def on_stream_start(self, stream_data):
|
||||
# Get channel subscriptions listening for this streamer
|
||||
uid = stream_data.streamerid
|
||||
uid = int(stream_data.streamerid)
|
||||
logger.info(f"Streamer <uid:{uid}> started streaming! {stream_data=}")
|
||||
subbed = await self.data.AlertChannel.fetch_where(streamerid=uid)
|
||||
|
||||
@@ -197,7 +197,7 @@ class AlertCog(LionCog):
|
||||
return
|
||||
|
||||
# Build message
|
||||
streamer = await self.data.Streamer.fetch(stream_data.streamerid)
|
||||
streamer = await self.data.Streamer.fetch(int(stream_data.streamerid))
|
||||
if not streamer:
|
||||
# Streamer was deleted while handling the alert
|
||||
# Just quietly ignore
|
||||
@@ -235,7 +235,7 @@ class AlertCog(LionCog):
|
||||
|
||||
# Store sent alert
|
||||
alert = await self.data.StreamAlert.create(
|
||||
streamid=stream_data.streamid,
|
||||
streamid=int(stream_data.streamid),
|
||||
subscriptionid=subscription.subscriptionid,
|
||||
sent_at=utc_now(),
|
||||
messageid=message.id
|
||||
@@ -246,7 +246,7 @@ class AlertCog(LionCog):
|
||||
|
||||
async def on_stream_end(self, stream_data):
|
||||
# Get channel subscriptions listening for this streamer
|
||||
uid = stream_data.streamerid
|
||||
uid = int(stream_data.streamerid)
|
||||
logger.info(f"Streamer <uid:{uid}> stopped streaming! {stream_data=}")
|
||||
subbed = await self.data.AlertChannel.fetch_where(streamerid=uid)
|
||||
|
||||
@@ -269,8 +269,8 @@ class AlertCog(LionCog):
|
||||
async def sub_resolve(self, subscription, stream_data):
|
||||
# Check if there is a current active alert to resolve
|
||||
alerts = await self.data.StreamAlert.fetch_where(
|
||||
streamid=stream_data.streamid,
|
||||
subscriptionid=subscription.subscriptionid,
|
||||
streamid=int(stream_data.streamid),
|
||||
subscriptionid=int(subscription.subscriptionid),
|
||||
)
|
||||
if not alerts:
|
||||
logger.info(
|
||||
@@ -322,7 +322,7 @@ class AlertCog(LionCog):
|
||||
)
|
||||
else:
|
||||
# Edit message with custom arguments
|
||||
streamer = await self.data.Streamer.fetch(stream_data.streamerid)
|
||||
streamer = await self.data.Streamer.fetch(int(stream_data.streamerid))
|
||||
formatter = await edit_setting.generate_formatter(self.bot, stream_data, streamer)
|
||||
formatted = await formatter(edit_setting.value)
|
||||
args = edit_setting.value_to_args(subscription.subscriptionid, formatted)
|
||||
@@ -400,7 +400,7 @@ class AlertCog(LionCog):
|
||||
|
||||
# Create streamer data if it doesn't already exist
|
||||
streamer_data = await self.data.Streamer.fetch_or_create(
|
||||
tw_user.id,
|
||||
int(tw_user.id),
|
||||
login_name=tw_user.login,
|
||||
display_name=tw_user.display_name,
|
||||
)
|
||||
@@ -418,8 +418,10 @@ class AlertCog(LionCog):
|
||||
self.watching[streamer_data.userid] = streamer_data
|
||||
|
||||
# Open AlertEditorUI for the new subscription
|
||||
# TODO
|
||||
await ctx.reply("StreamAlert Created.")
|
||||
ui = AlertEditorUI(bot=self.bot, sub_data=sub_data, callerid=ctx.author.id)
|
||||
await ui.run(ctx.interaction)
|
||||
await ui.wait()
|
||||
|
||||
async def alert_acmpl(self, interaction: discord.Interaction, partial: str):
|
||||
if not interaction.guild:
|
||||
|
||||
8
src/modules/tagstrings/__init__.py
Normal file
8
src/modules/tagstrings/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .cog import TagCog
|
||||
|
||||
def prepare(bot):
|
||||
bot.add_cog(TagCog(bot))
|
||||
152
src/modules/tagstrings/cog.py
Normal file
152
src/modules/tagstrings/cog.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
import difflib
|
||||
|
||||
import twitchio
|
||||
from twitchio.ext import commands
|
||||
|
||||
from meta import CrocBot
|
||||
from utils.lib import utc_now
|
||||
from . import logger
|
||||
from .data import TagData
|
||||
|
||||
|
||||
class TagCog(commands.Cog):
|
||||
def __init__(self, bot: CrocBot):
|
||||
self.bot = bot
|
||||
self.data = bot.data.load_registry(TagData())
|
||||
|
||||
self.loaded = asyncio.Event()
|
||||
|
||||
# Cache of channel tags, channelid -> name.lower() -> Tag
|
||||
self.tags: dict[int, dict[str, TagData.Tag]] = {}
|
||||
|
||||
async def load_tags(self):
|
||||
tags = defaultdict(dict)
|
||||
|
||||
rows = await self.data.Tag.fetch_where()
|
||||
for row in rows:
|
||||
tags[row.channelid][row.name.lower()] = row
|
||||
|
||||
self.tags.clear()
|
||||
self.tags.update(tags)
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
await self.load_tags()
|
||||
self.loaded.set()
|
||||
|
||||
async def ensure_loaded(self):
|
||||
if not self.loaded.is_set():
|
||||
await self.cog_load()
|
||||
|
||||
@commands.Cog.event('event_ready')
|
||||
async def on_ready(self):
|
||||
await self.ensure_loaded()
|
||||
|
||||
# API
|
||||
|
||||
async def create_tag(self, channelid: int, name: str, content: str, created_by: int):
|
||||
"""
|
||||
Create a new Tag with the given parameters.
|
||||
|
||||
If the tag already exists, will raise (TODO)
|
||||
"""
|
||||
row = await self.data.Tag.create(
|
||||
channelid=channelid,
|
||||
name=name,
|
||||
content=content,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
if (chantags := self.tags.get(channelid, None)) is None:
|
||||
chantags = self.tags[channelid] = {}
|
||||
chantags[name.lower()] = row
|
||||
|
||||
logger.info(f"Created Tag: {row!r}")
|
||||
|
||||
return row
|
||||
|
||||
# Commands
|
||||
|
||||
@commands.command()
|
||||
async def edittag(self, ctx: commands.Context, tagname: str, *, content: str):
|
||||
"""
|
||||
Create or edit a tag.
|
||||
"""
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
userid = int(ctx.author.id)
|
||||
|
||||
# Fetch the tag if it exists
|
||||
tag = self.tags.get(channelid, {}).get(tagname.lower(), None)
|
||||
|
||||
if tag is None:
|
||||
# Create new tag
|
||||
tag = await self.create_tag(
|
||||
channelid,
|
||||
tagname,
|
||||
content,
|
||||
userid
|
||||
)
|
||||
await ctx.reply(f"Tag '{tagname}' created as #{tag.tagid}!")
|
||||
else:
|
||||
# Edit existing tag
|
||||
if not (ctx.author.is_mod or ctx.author.is_broadcaster or userid == tag.created_by):
|
||||
await ctx.reply("You can't edit this tag!")
|
||||
return
|
||||
|
||||
await tag.update(
|
||||
content=content,
|
||||
updated_at=utc_now()
|
||||
)
|
||||
|
||||
await ctx.reply(f"Updated '{tag.name}'")
|
||||
|
||||
@commands.command()
|
||||
async def deltag(self, ctx: commands.Context, tagname: str):
|
||||
if ctx.author.is_broadcaster or ctx.author.is_mod:
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
tag = self.tags.get(channelid, {}).get(tagname.lower(), None)
|
||||
if tag is None:
|
||||
await ctx.reply(f"Couldn't find '{tagname}' to delete!")
|
||||
else:
|
||||
self.tags[channelid].pop(tag.name.lower())
|
||||
await tag.delete()
|
||||
await ctx.reply(f"Deleted '{tag.name}'")
|
||||
|
||||
@commands.command()
|
||||
async def tag(self, ctx: commands.Context, tagname: str):
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
tags = self.tags.get(channelid, {})
|
||||
if (tag := tags.get(tagname.lower(), None)) is None:
|
||||
# Search for closest match
|
||||
|
||||
matches = difflib.get_close_matches(tagname.lower(), tags.keys(), n=2)
|
||||
matchstr = "'{}'".format("' or '".join(matches)) if matches else None
|
||||
suffix = f"Did you mean {matchstr}?" if matches else ""
|
||||
|
||||
await ctx.reply(f"Couldn't find tag '{tagname}'! {suffix}")
|
||||
return
|
||||
await ctx.reply(tag.content)
|
||||
|
||||
@commands.command(name='tags')
|
||||
async def cmd_tags(self, ctx: commands.Context, *, searchstr: str = ''):
|
||||
"""
|
||||
List the tags available in the current channel.
|
||||
"""
|
||||
channelid = int((await ctx.channel.user()).id)
|
||||
tag_names = [tag.name for tag in self.tags.get(channelid, {}).values()]
|
||||
matching = [name for name in tag_names if searchstr.lower() in name.lower()]
|
||||
tagstr = ', '.join(matching)
|
||||
|
||||
if searchstr:
|
||||
if matching:
|
||||
await ctx.reply(f"Matching tags: {tagstr}")
|
||||
else:
|
||||
await ctx.reply(f"No tags matching '{searchstr}'")
|
||||
else:
|
||||
if matching:
|
||||
await ctx.reply(f"Available tags: {tagstr}")
|
||||
else:
|
||||
await ctx.reply("No tags set up on this channel!")
|
||||
30
src/modules/tagstrings/data.py
Normal file
30
src/modules/tagstrings/data.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from data import Registry, RowModel
|
||||
from data.columns import Integer, String, Timestamp
|
||||
|
||||
|
||||
class TagData(Registry):
|
||||
class Tag(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE channel_tags(
|
||||
tagid SERIAL PRIMARY KEY,
|
||||
channelid BIGINT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_by BIGINT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ
|
||||
);
|
||||
CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name);
|
||||
"""
|
||||
_tablename_ = 'channel_tags'
|
||||
_cache_ ={}
|
||||
|
||||
tagid = Integer(primary=True)
|
||||
channelid = Integer()
|
||||
name = String()
|
||||
content = String()
|
||||
created_by = Integer()
|
||||
created_at = Timestamp()
|
||||
updated_at = Timestamp()
|
||||
Reference in New Issue
Block a user