Initial merger with Twitch interface.

This commit is contained in:
2024-08-27 17:41:33 +10:00
parent 7e6dcb006f
commit d10fd2fc1d
28 changed files with 1309 additions and 40 deletions

Binary file not shown.

Binary file not shown.

BIN
assets/pomodoro/chime.mp3 Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1406,6 +1406,65 @@ CREATE TABLE stream_alerts(
-- }}} -- }}}
-- Nowlist {{{
CREATE TABLE nowlist_tasks(
userid BIGINT PRIMARY KEY,
name TEXT NOT NULL,
task TEXT NOT NULL,
started_at TIMESTAMPTZ NOT NULL,
done_at TIMESTAMPTZ
);
-- }}}
-- Shoutouts {{{
CREATE TABLE shoutouts(
userid BIGINT PRIMARY KEY,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- }}}
-- Counters {{{
CREATE TABLE counters(
counterid SERIAL PRIMARY KEY,
name TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX counters_name ON counters (name);
CREATE TABLE counter_log(
entryid SERIAL PRIMARY KEY,
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
userid INTEGER NOT NULL,
value INTEGER NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
context_str TEXT
);
CREATE INDEX counter_log_counterid ON counter_log (counterid);
-- }}}
-- Tags {{{
CREATE TABLE channel_tags(
tagid SERIAL PRIMARY KEY,
channelid BIGINT NOT NULL,
name TEXT NOT NULL,
content TEXT NOT NULL,
created_by BIGINT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ
);
CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name);
-- }}}
-- Analytics Data {{{ -- Analytics Data {{{
CREATE SCHEMA "analytics"; CREATE SCHEMA "analytics";

View File

@@ -1,15 +1,17 @@
import asyncio import asyncio
from contextlib import AsyncExitStack
import logging import logging
import websockets
import aiohttp import aiohttp
import discord import discord
from discord.ext import commands from discord.ext import commands
from meta import LionBot, conf, sharding, appname, shard_talk from meta import CrocBot, LionBot, conf, sharding, appname, shard_talk, sockets, args
from meta.app import shardname from meta.app import shardname
from meta.logger import log_context, log_action_stack, setup_main_logger from meta.logger import log_context, log_action_stack, setup_main_logger
from meta.context import ctx_bot from meta.context import ctx_bot
from meta.monitor import ComponentMonitor, StatusLevel, ComponentStatus from meta.monitor import ComponentMonitor, StatusLevel, ComponentStatus, SystemMonitor
from data import Database from data import Database
@@ -58,18 +60,28 @@ async def main():
intents.message_content = True intents.message_content = True
intents.presences = False intents.presences = False
async with db.open(): system_monitor = SystemMonitor()
async with AsyncExitStack() as stack:
await stack.enter_async_context(db.open())
version = await db.version() version = await db.version()
if version.version != DATA_VERSION: if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate." error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
logger.critical(error) logger.critical(error)
raise RuntimeError(error) raise RuntimeError(error)
system_monitor.add_component(ComponentMonitor('Database', _data_monitor))
translator = LeoBabel() translator = LeoBabel()
ctx_translator.set(translator) ctx_translator.set(translator)
async with aiohttp.ClientSession() as session: session = await stack.enter_async_context(aiohttp.ClientSession())
async with LionBot( await stack.enter_async_context(
websockets.serve(sockets.root_handler, '', conf.wserver['port'])
)
lionbot = await stack.enter_async_context(
LionBot(
command_prefix='!', command_prefix='!',
intents=intents, intents=intents,
appname=appname, appname=appname,
@@ -91,11 +103,30 @@ async def main():
proxy=conf.bot.get('proxy', None), proxy=conf.bot.get('proxy', None),
translator=translator, translator=translator,
chunk_guilds_at_startup=False, chunk_guilds_at_startup=False,
) as lionbot: system_monitor=system_monitor,
ctx_bot.set(lionbot)
lionbot.system_monitor.add_component(
ComponentMonitor('Database', _data_monitor)
) )
)
crocbot = CrocBot(
config=conf,
data=db,
prefix='!',
initial_channels=conf.croccy.getlist('initial_channels'),
token=conf.croccy['token'],
lionbot=lionbot
)
lionbot.crocbot = crocbot
crocbot.load_module('modules')
crocstart = asyncio.create_task(start_croccy(crocbot))
lionstart = asyncio.create_task(start_lion(lionbot))
await asyncio.wait((crocstart, lionstart), return_when=asyncio.FIRST_COMPLETED)
crocstart.cancel()
lionstart.cancel()
async def start_lion(lionbot):
ctx_bot.set(lionbot)
try: try:
log_context.set(f"APP: {appname}") log_context.set(f"APP: {appname}")
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'}) logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
@@ -104,6 +135,18 @@ async def main():
log_context.set(f"APP: {appname}") log_context.set(f"APP: {appname}")
logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True) logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
async def start_croccy(crocbot):
try:
log_context.set(f"APP: {appname}-croccy")
logger.info("Starting Twitch bot.", extra={'action': 'Starting'})
await crocbot.start()
except asyncio.CancelledError:
logger.info("Croccybot shutting down gracefully.")
except Exception:
logger.exception("Croccybot shutting down ungracefully.")
finally:
await crocbot.close()
def _main(): def _main():
from signal import SIGINT, SIGTERM from signal import SIGINT, SIGTERM

32
src/meta/CrocBot.py Normal file
View File

@@ -0,0 +1,32 @@
from typing import TYPE_CHECKING
import logging
from twitchio.ext import commands
from twitchio.ext import pubsub
from data import Database
from .config import Conf
if TYPE_CHECKING:
from .LionBot import LionBot
logger = logging.getLogger(__name__)
class CrocBot(commands.Bot):
def __init__(self, *args,
config: Conf,
data: Database,
lionbot: 'LionBot', **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.data = data
self.pubsub = pubsub.PubSubPool(self)
self.lionbot = lionbot
async def event_ready(self):
logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")

View File

@@ -56,7 +56,9 @@ class LionBot(Bot):
def __init__( def __init__(
self, *args, appname: str, shardname: str, db: Database, config: Conf, translator: LeoBabel, self, *args, appname: str, shardname: str, db: Database, config: Conf, translator: LeoBabel,
initial_extensions: List[str], web_client: ClientSession, app_ipc, initial_extensions: List[str], web_client: ClientSession, app_ipc,
testing_guilds: List[int] = [], **kwargs testing_guilds: List[int] = [],
system_monitor: Optional[SystemMonitor] = None,
**kwargs
): ):
kwargs.setdefault('tree_cls', LionTree) kwargs.setdefault('tree_cls', LionTree)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -71,7 +73,7 @@ class LionBot(Bot):
self.app_ipc = app_ipc self.app_ipc = app_ipc
self.translator = translator self.translator = translator
self.system_monitor = SystemMonitor() self.system_monitor = system_monitor or SystemMonitor()
self.monitor = ComponentMonitor('LionBot', self._monitor_status) self.monitor = ComponentMonitor('LionBot', self._monitor_status)
self.system_monitor.add_component(self.monitor) self.system_monitor.add_component(self.monitor)

View File

@@ -3,6 +3,8 @@ from .LionCog import LionCog
from .LionContext import LionContext from .LionContext import LionContext
from .LionTree import LionTree from .LionTree import LionTree
from .CrocBot import CrocBot
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
from .config import conf, configEmoji from .config import conf, configEmoji
from .args import args from .args import args
@@ -10,6 +12,7 @@ from .app import appname, shard_talk, appname_from_shard, shard_from_appname
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
from .context import context, ctx_bot from .context import context, ctx_bot
from . import sockets
from . import sharding from . import sharding
from . import logger from . import logger
from . import app from . import app

68
src/meta/sockets.py Normal file
View File

@@ -0,0 +1,68 @@
from abc import ABC
from collections import defaultdict
import json
from typing import Any
import logging
import websockets
logger = logging.getLogger(__name__)
class Channel(ABC):
"""
A channel is a stateful connection handler for a group of connected websockets.
"""
name = "Root Channel"
def __init__(self, **kwargs):
self.connections = set()
@property
def empty(self):
return not self.connections
async def on_connection(self, websocket: websockets.WebSocketServerProtocol, event: dict[str, Any]):
logger.info(f"Channel '{self.name}' attached new connection {websocket=} {event=}")
self.connections.add(websocket)
async def del_connection(self, websocket: websockets.WebSocketServerProtocol):
logger.info(f"Channel '{self.name}' dropped connection {websocket=}")
self.connections.discard(websocket)
async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message):
raise NotImplementedError
async def send_event(self, event, websocket=None):
message = json.dumps(event)
if not websocket:
for ws in self.connections:
await ws.send(message)
else:
await websocket.send(message)
channels = {}
def register_channel(name, channel: Channel):
channels[name] = channel
async def root_handler(websocket: websockets.WebSocketServerProtocol):
message = await websocket.recv()
event = json.loads(message)
if event.get('type', None) != 'init':
raise ValueError("Received Websocket connection with no init.")
if (channel_name := event.get('channel', None)) not in channels:
raise ValueError(f"Received Init for unhandled channel {channel_name=}")
channel = channels[channel_name]
try:
await channel.on_connection(websocket, event)
async for message in websocket:
await channel.handle_message(websocket, message)
finally:
await channel.del_connection(websocket)

View File

@@ -1,6 +1,6 @@
this_package = 'modules' this_package = 'modules'
active = [ active_discord = [
'.sysadmin', '.sysadmin',
'.config', '.config',
'.user_config', '.user_config',
@@ -28,7 +28,18 @@ active = [
'.test', '.test',
] ]
active_twitch = [
'.nowdoing',
'.shoutouts',
'.counters',
'.tagstrings',
]
def prepare(bot):
for ext in active_twitch:
bot.load_module(this_package + ext)
async def setup(bot): async def setup(bot):
for ext in active: for ext in active_discord:
await bot.load_extension(ext, package=this_package) await bot.load_extension(ext, package=this_package)

View File

@@ -0,0 +1,13 @@
import logging
logger = logging.getLogger(__name__)
from .cog import CounterCog
def prepare(bot):
bot.add_cog(CounterCog(bot))
async def setup(bot):
from .lion_cog import CounterCog
await bot.add_cog(CounterCog(bot))

299
src/modules/counters/cog.py Normal file
View File

@@ -0,0 +1,299 @@
import asyncio
from enum import Enum
from typing import Optional
from datetime import timedelta
import twitchio
from twitchio.ext import commands
from data.queries import ORDER
from meta import CrocBot
from utils.lib import utc_now
from . import logger
from .data import CounterData
class PERIOD(Enum):
ALL = ('', 'all', 'all-time')
STREAM = ('this stream', 'stream',)
DAY = ('today', 'd', 'day', 'today', 'daily')
WEEK = ('this week', 'w', 'week', 'weekly')
MONTH = ('this month', 'm', 'mo', 'month', 'monthly')
YEAR = ('this year', 'y', 'year', 'yearly')
class CounterCog(commands.Cog):
def __init__(self, bot: CrocBot):
self.bot = bot
self.data = bot.data.load_registry(CounterData())
self.loaded = asyncio.Event()
# Cache of counter names -> rows
self.counters = {}
async def cog_load(self):
await self.data.init()
self.loaded.set()
async def load_counters(self):
"""
Initialise counter name cache.
"""
rows = await self.data.Counter.fetch_where()
self.counters = {row.name: row for row in rows}
logger.info(
f"Loaded {len(self.counters)} counters."
)
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
@commands.Cog.event('event_ready') # type: ignore
async def on_ready(self):
await self.ensure_loaded()
async def cog_check(self, ctx):
await self.ensure_loaded()
return True
# Counters API
async def fetch_counter(self, counter: str) -> CounterData.Counter:
"""
Fetches the Counter with the given name,
or creates it if it doesn't exist.
"""
if (row := self.counters.get(counter, None)) is None:
row = await self.data.Counter.fetch_or_create(name=counter)
self.counters[counter] = row
return row
async def delete_counter(self, counter: str):
self.counters.pop(counter, None)
await self.data.Counter.table.delete_where(name=counter)
async def reset_counter(self, counter: str):
row = self.counters.get(counter, None)
if row:
await self.data.CounterEntry.table.delete_where(counterid=row.counterid)
async def add_to_counter(self, counter: str, userid: int, value: int, context: Optional[str]=None):
row = await self.fetch_counter(counter)
return await self.data.CounterEntry.create(
counterid=row.counterid,
userid=userid,
value=value,
context_str=context
)
async def leaderboard(self, counter: str, start_time=None):
row = await self.fetch_counter(counter)
query = self.data.CounterEntry.table.select_where(counterid=row.counterid)
query.select('userid', user_total="SUM(value)")
query.group_by('userid')
query.order_by('user_total', ORDER.DESC)
if start_time is not None:
query.where(self.data.CounterEntry.created_at >= start_time)
query.with_no_adapter()
results = await query
lb = {result['userid']: result['user_total'] for result in results}
return lb
async def personal_total(self, counter: str, userid: int):
row = await self.fetch_counter(counter)
query = self.data.CounterEntry.table.select_where(counterid=row.counterid, userid=userid)
query.select(user_total="SUM(value)")
query.with_no_adapter()
results = await query
return results[0]['user_total'] if results else 0
async def totals(self, counter):
row = await self.fetch_counter(counter)
query = self.data.CounterEntry.table.select_where(counterid=row.counterid)
query.select(counter_total="SUM(value)")
query.with_no_adapter()
results = await query
return results[0]['counter_total'] if results else 0
# Counters commands
@commands.command()
async def counter(self, ctx: commands.Context, name: str, subcmd: Optional[str], *, args: Optional[str]=None):
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
return
name = name.lower()
if subcmd is None or subcmd == 'show':
# Show
total = await self.totals(name)
await ctx.reply(f"'{name}' counter is: {total}")
elif subcmd == 'add':
if args is None:
value = 1
else:
try:
value = int(args)
except ValueError:
await ctx.reply(f"Could not parse value to add.")
return
await self.add_to_counter(
name,
int(ctx.author.id),
value,
context='cmd: counter add'
)
total = await self.totals(name)
await ctx.reply(f"'{name}' counter is now: {total}")
elif subcmd == 'lb':
user = await ctx.channel.user()
lbstr = await self.formatted_lb(name, args or '', int(user.id))
await ctx.reply(lbstr)
elif subcmd == 'clear':
await self.reset_counter(name)
await ctx.reply(f"'{name}' counter reset.")
else:
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear'.")
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM):
if periodstr:
period = next((period for period in PERIOD if periodstr.lower() in period.value), None)
if period is None:
raise ValueError("Invalid period string provided")
else:
period = default
now = utc_now()
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
if period is PERIOD.ALL:
start_time = None
elif period is PERIOD.STREAM:
streams = await self.bot.fetch_streams(user_ids=[userid])
if streams:
stream = streams[0]
start_time = stream.started_at
else:
period = PERIOD.ALL
start_time = None
elif period is PERIOD.DAY:
start_time = today
elif period is PERIOD.WEEK:
start_time = today - timedelta(days=today.weekday())
elif period is PERIOD.MONTH:
start_time = today.replace(day=1)
elif period is PERIOD.YEAR:
start_time = today.replace(day=1, month=1)
else:
period = PERIOD.ALL
start_time = None
return (period, start_time)
async def formatted_lb(self, counter: str, periodstr: str, channelid: int):
period, start_time = await self.parse_period(channelid, periodstr)
lb = await self.leaderboard(counter, start_time=start_time)
if lb:
userids = list(lb.keys())
users = await self.bot.fetch_users(ids=userids)
name_map = {user.id: user.display_name for user in users}
parts = []
for userid, total in lb.items():
name = name_map.get(userid, str(userid))
part = f"{name}: {total}"
parts.append(part)
lbstr = '; '.join(parts)
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}"
else:
return f"{counter} {period.value[-1]} leaderboard is empty!"
# Misc actual counter commands
# TODO: Factor this out to a different module...
@commands.command()
async def tea(self, ctx: commands.Context, *, args: Optional[str]=None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await self.parse_period(channelid, '')
counter = 'tea'
await self.add_to_counter(
counter,
userid,
1,
context='cmd: tea'
)
lb = await self.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(f"Enjoy your tea! We have had {total} cups of tea {period.value[0]}.")
@commands.command()
async def tealb(self, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await self.formatted_lb('tea', args, int(user.id)))
@commands.command()
async def coffee(self, ctx: commands.Context, *, args: Optional[str]=None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await self.parse_period(channelid, '')
counter = 'coffee'
await self.add_to_counter(
counter,
userid,
1,
context='cmd: coffee'
)
lb = await self.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(f"Enjoy your coffee! We have had {total} cups of coffee {period.value[0]}.")
@commands.command()
async def coffeelb(self, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await self.formatted_lb('coffee', args, int(user.id)))
@commands.command()
async def water(self, ctx: commands.Context, *, args: Optional[str]=None):
userid = int(ctx.author.id)
channelid = int((await ctx.channel.user()).id)
period, start_time = await self.parse_period(channelid, '')
counter = 'water'
await self.add_to_counter(
counter,
userid,
1,
context='cmd: water'
)
lb = await self.leaderboard(counter, start_time=start_time)
user_total = lb.get(userid, 0)
total = sum(lb.values())
await ctx.reply(f"Good job hydrating! We have had {total} cups of water {period.value[0]}.")
@commands.command()
async def waterlb(self, ctx: commands.Context, *, args: str = ''):
user = await ctx.channel.user()
await ctx.reply(await self.formatted_lb('water', args, int(user.id)))
@commands.command()
async def reload(self, ctx: commands.Context, *, args: str = ''):
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
return
if not args:
await ctx.reply("Full reload not implemented yet.")
else:
try:
self.bot.reload_module(args)
except Exception:
logger.exception("Failed to reload")
await ctx.reply("Failed to reload module! Check console~")
else:
await ctx.reply("Reloaded!")

View File

@@ -0,0 +1,48 @@
from data import Registry, RowModel
from data.columns import Integer, String, Timestamp
class CounterData(Registry):
class Counter(RowModel):
"""
Schema
------
CREATE TABLE counters(
counterid SERIAL PRIMARY KEY,
name TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX counters_name ON counters (name);
"""
_tablename_ = 'counters'
_cache_ = {}
counterid = Integer(primary=True)
name = String()
created_at = Timestamp()
class CounterEntry(RowModel):
"""
Schema
------
CREATE TABLE counter_log(
entryid SERIAL PRIMARY KEY,
counterid INTEGER NOT NULL REFERENCES counters (counterid) ON UPDATE CASCADE ON DELETE CASCADE,
userid INTEGER NOT NULL,
value INTEGER NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
context_str TEXT
);
CREATE INDEX counter_log_counterid ON counter_log (counterid);
"""
_tablename_ = 'counter_log'
_cache_ = {}
entryid = Integer(primary=True)
counterid = Integer()
userid = Integer()
value = Integer()
created_at = Timestamp()
context_str = String()

View File

@@ -0,0 +1,23 @@
import asyncio
from typing import Optional
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionCog, LionContext
from meta.errors import UserInputError
from meta.logger import log_wrap
from utils.lib import utc_now
from data.conditions import NULL
from . import logger
from .data import CounterData
class CounterCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.counter_cog = bot.crocbot.get_cog('CounterCog')

View File

@@ -0,0 +1,9 @@
import logging
logger = logging.getLogger(__name__)
from .cog import NowDoingCog
def prepare(bot):
logger.info("Preparing the nowdoing module.")
bot.add_cog(NowDoingCog(bot))

253
src/modules/nowdoing/cog.py Normal file
View File

@@ -0,0 +1,253 @@
import asyncio
import datetime as dt
import json
import os
from typing import Optional
from attr import dataclass
import twitchio
from twitchio.ext import commands
from meta import CrocBot
from meta.sockets import Channel, register_channel
from utils.lib import strfdelta, utc_now
from . import logger
from .data import NowListData
class NowDoingChannel(Channel):
name = 'NowList'
def __init__(self, cog: 'NowDoingCog', **kwargs):
self.cog = cog
super().__init__(**kwargs)
async def on_connection(self, websocket, event):
await super().on_connection(websocket, event)
for task in self.cog.tasks.values():
await self.send_set(*self.task_args(task), websocket=websocket)
async def send_test_set(self):
tasks = [
(0, 'Tester0', "Testing Tasklist", True),
(1, 'Tester1', "Getting Confused", False),
(2, "Tester2", "Generating Bugs", True),
(3, "Tester3", "Fixing Bugs", False),
(4, "Tester4", "Pushing the red button", False),
]
for task in tasks:
await self.send_set(*task)
def task_args(self, task: NowListData.Task):
return (
task.userid,
task.name,
task.task,
task.started_at.isoformat(),
task.done_at.isoformat() if task.done_at else None,
)
async def send_set(self, userid, name, task, start_at, end_at, websocket=None):
await self.send_event({
'type': "DO",
'method': "setTask",
'args': {
'userid': userid,
'name': name,
'task': task,
'start_at': start_at,
'end_at': end_at,
}
}, websocket=websocket)
async def send_del(self, userid):
await self.send_event({
'type': "DO",
'method': "delTask",
'args': {
'userid': userid,
}
})
async def send_clear(self):
await self.send_event({
'type': "DO",
'method': "clearTasks",
'args': {
}
})
class NowDoingCog(commands.Cog):
def __init__(self, bot: CrocBot):
self.bot = bot
self.data = bot.data.load_registry(NowListData())
self.channel = NowDoingChannel(self)
register_channel(self.channel.name, self.channel)
# userid -> Task
self.tasks: dict[int, NowListData.Task] = {}
self.loaded = asyncio.Event()
async def cog_load(self):
await self.data.init()
await self.load_tasks()
self.loaded.set()
async def ensure_loaded(self):
"""
Hack because lib devs decided to remove async cog loading.
"""
if not self.loaded.is_set():
await self.cog_load()
@commands.Cog.event('event_ready') # type: ignore
async def on_ready(self):
await self.ensure_loaded()
async def cog_check(self, ctx):
await self.ensure_loaded()
return True
async def load_tasks(self):
tasklist = await self.data.Task.fetch_where()
tasks = {task.userid: task for task in tasklist}
self.tasks = tasks
logger.info(f"Loaded {len(tasks)} from database.")
@commands.command()
async def test(self, ctx: commands.Context):
if (ctx.author.is_broadcaster):
# await self.channel.send_test_set()
# await ctx.send(f"Hello {ctx.author.name}! This command does something, we aren't sure what yet.")
# await ctx.send(str(list(self.tasks.items())[0]))
await ctx.send(str(ctx.author.id))
else:
await ctx.send(f"Hello {ctx.author.name}! I don't think you have permission to test that.")
@commands.command(aliases=['task', 'check'])
async def now(self, ctx: commands.Context, *, args: Optional[str] = None):
userid = int(ctx.author.id)
if args:
await self.data.Task.table.delete_where(userid=userid)
task = await self.data.Task.create(
userid=userid,
name=ctx.author.display_name,
task=args,
started_at=utc_now(),
)
self.tasks[task.userid] = task
await self.channel.send_set(*self.channel.task_args(task))
await ctx.send(f"Updated your current task, good luck!")
elif task := self.tasks.get(userid, None):
if task.is_done:
done_ago = strfdelta(utc_now() - task.done_at)
await ctx.send(
f"You finished '{task.task}' {done_ago} ago!"
)
else:
started_ago = strfdelta(utc_now() - task.started_at)
await ctx.send(
f"You have been working on '{task.task}' for {started_ago}!"
)
else:
await ctx.send(
"You don't have a task on the tasklist! "
"Show what you are currently working on with, e.g. !now Reading notes"
)
@commands.command(name='next')
async def nownext(self, ctx: commands.Context, *, args: Optional[str] = None):
userid = int(ctx.author.id)
task = self.tasks.get(userid, None)
if args:
if task:
if not task.is_done:
await task.update(done_at=utc_now())
started_ago = strfdelta(task.done_at - task.started_at)
prefix = (
f"You worked on '{task.task}' for {started_ago}."
)
else:
prefix = ""
await self.data.Task.table.delete_where(userid=userid)
task = await self.data.Task.create(
userid=userid,
name=ctx.author.display_name,
task=args,
started_at=utc_now(),
)
self.tasks[task.userid] = task
await self.channel.send_set(*self.channel.task_args(task))
await ctx.send(f"Next task set, good luck!" + ' ' + prefix)
elif task:
if task.is_done:
done_ago = strfdelta(utc_now() - task.done_at)
await ctx.send(
f"You finished '{task.task}' {done_ago} ago!"
)
else:
started_ago = strfdelta(utc_now() - task.started_at)
await ctx.send(
f"You have been working on '{task.task}' for {started_ago}!"
)
else:
await ctx.send(
"You don't have a task on the tasklist! "
"Show what you are currently working on with, e.g. !now Reading notes"
)
@commands.command()
async def done(self, ctx: commands.Context):
userid = int(ctx.author.id)
if task := self.tasks.get(userid, None):
if task.is_done:
await ctx.send(
f"You already finished '{task.task}'!"
)
else:
await task.update(done_at=utc_now())
started_ago = strfdelta(task.done_at - task.started_at)
await self.channel.send_set(*self.channel.task_args(task))
await ctx.send(
f"Good job finishing '{task.task}'! "
f"You worked on it for {started_ago}."
)
else:
await ctx.send(
"You don't have a task on the tasklist! "
"Show what you are currently working on with, e.g. !now Reading notes"
)
@commands.command()
async def clear(self, ctx: commands.Context):
userid = int(ctx.author.id)
if task := self.tasks.pop(userid, None):
await task.delete()
await self.channel.send_del(userid)
await ctx.send("Removed your task from the tasklist!")
else:
await ctx.send(
"You don't have a task on the tasklist at the moment!"
)
@commands.command()
async def clearfor(self, ctx: commands.Context, user: twitchio.User):
if (ctx.author.is_mod or ctx.author.is_broadcaster):
await self.channel.send_del(int(user.id))
task = self.tasks.pop(int(user.id), None)
if task is not None:
await task.delete()
await ctx.send("Cleared the task.")
else:
pass
@commands.command()
async def clearall(self, ctx: commands.Context):
if (ctx.author.is_mod or ctx.author.is_broadcaster):
await self.data.Task.table.delete_where()
self.tasks.clear()
await self.channel.send_clear()
await ctx.send("Tasklist Cleared!")

View File

@@ -0,0 +1,29 @@
from data import Registry, RowModel
from data.columns import Integer, Timestamp, String
class NowListData(Registry):
class Task(RowModel):
"""
Schema
------
CREATE TABLE nowlist_tasks(
userid BIGINT PRIMARY KEY,
name TEXT NOT NULL,
task TEXT NOT NULL,
started_at TIMESTAMPTZ NOT NULL,
done_at TIMESTAMPTZ
);
"""
_tablename_ = 'nowlist_tasks'
_cache_ = {}
userid = Integer(primary=True)
name = String()
task = String()
started_at = Timestamp()
done_at = Timestamp()
@property
def is_done(self):
return self.done_at is not None

View File

@@ -13,6 +13,7 @@ from meta.sharding import THIS_SHARD
from meta.monitor import ComponentMonitor, ComponentStatus, StatusLevel from meta.monitor import ComponentMonitor, ComponentStatus, StatusLevel
from utils.lib import utc_now from utils.lib import utc_now
from utils.ratelimits import limit_concurrency from utils.ratelimits import limit_concurrency
from meta.sockets import Channel, register_channel
from wards import low_management_ward from wards import low_management_ward
@@ -39,6 +40,37 @@ _param_options = {
} }
class TimerChannel(Channel):
name = 'Timer'
def __init__(self, cog: 'TimerCog', **kwargs):
super().__init__(**kwargs)
self.cog = cog
async def on_connection(self, websocket, event):
await super().on_connection(websocket, event)
timer = self.cog.get_channel_timer(1261999440160624734)
if timer is not None:
await self.send_set(
timer.data.last_started,
timer.data.focus_length,
timer.data.break_length,
websocket=websocket,
)
async def send_set(self, start_at, focus_length, break_length, goal=12, websocket=None):
await self.send_event({
'type': "DO",
'method': 'setTimer',
'args': {
'start_at': start_at.isoformat(),
'focus_length': focus_length,
'break_length': break_length,
'block_goal': goal,
}
}, websocket=websocket)
class TimerCog(LionCog): class TimerCog(LionCog):
def __init__(self, bot: LionBot): def __init__(self, bot: LionBot):
self.bot = bot self.bot = bot
@@ -46,6 +78,9 @@ class TimerCog(LionCog):
self.settings = TimerSettings() self.settings = TimerSettings()
self.monitor = ComponentMonitor('TimerCog', self._monitor) self.monitor = ComponentMonitor('TimerCog', self._monitor)
self.channel = TimerChannel(self)
register_channel(self.channel.name, self.channel)
self.timer_options = TimerOptions() self.timer_options = TimerOptions()
self.ready = False self.ready = False
@@ -1012,3 +1047,31 @@ class TimerCog(LionCog):
ui = TimerConfigUI(self.bot, ctx.guild.id, ctx.channel.id) ui = TimerConfigUI(self.bot, ctx.guild.id, ctx.channel.id)
await ui.run(ctx.interaction) await ui.run(ctx.interaction)
await ui.wait() await ui.wait()
# ----- Hacky Stream commands -----
@cmds.hybrid_group('streamtimer', with_app_command=True)
async def streamtimer_group(self, ctx: LionContext):
...
@streamtimer_group.command(
name="update"
)
@low_management_ward
async def streamtimer_update_cmd(self, ctx: LionContext,
new_start: Optional[str] = None,
new_goal: int = 12):
timer = self.get_channel_timer(1261999440160624734)
if timer is None:
return
if new_start:
timezone = ctx.lmember.timezone
start_at = await self.bot.get_cog('Reminders').parse_time_static(new_start, timezone)
await timer.data.update(last_started=start_at)
await self.channel.send_set(
timer.data.last_started,
timer.data.focus_length,
timer.data.break_length,
goal=new_goal,
)
await ctx.reply("Stream Timer Updated")

View File

@@ -195,9 +195,7 @@ class Timer:
Uses voice channel member cache as source-of-truth. Uses voice channel member cache as source-of-truth.
""" """
if (chan := self.channel): if (chan := self.channel):
members = [ members = [m for m in chan.members if not m.bot]
member for member in chan.members if not member.bot and 1148167212901859328 in [role.id for role in member.roles]
]
else: else:
members = [] members = []
return members return members
@@ -480,6 +478,7 @@ class Timer:
if self.guild.voice_client: if self.guild.voice_client:
await self.guild.voice_client.disconnect(force=True) await self.guild.voice_client.disconnect(force=True)
alert_file = focus_alert_path if stage.focused else break_alert_path alert_file = focus_alert_path if stage.focused else break_alert_path
try: try:
voice_client = await asyncio.wait_for( voice_client = await asyncio.wait_for(
self.channel.connect(timeout=30, reconnect=False), self.channel.connect(timeout=30, reconnect=False),
@@ -613,7 +612,11 @@ class Timer:
if render: if render:
try: try:
card = await get_timer_card(self.bot, self, stage) card = await get_timer_card(self.bot, self, stage)
await card.render() data = await card.render()
import io
with io.BytesIO(data) as buffer:
with open(f"pomodoro_{self.data.channelid}.png", "wb") as f:
f.write(buffer.getbuffer())
rawargs['file'] = card.as_file(f"pomodoro_{self.data.channelid}.png") rawargs['file'] = card.as_file(f"pomodoro_{self.data.channelid}.png")
except RenderingException: except RenderingException:
pass pass
@@ -841,8 +844,8 @@ class Timer:
to_next_stage = (current.end - utc_now()).total_seconds() to_next_stage = (current.end - utc_now()).total_seconds()
# TODO: Consider request rate and load # TODO: Consider request rate and load
if to_next_stage > 5 * 60 - drift: if to_next_stage > 1 * 60 - drift:
time_to_sleep = 5 * 60 time_to_sleep = 1 * 60
else: else:
time_to_sleep = to_next_stage time_to_sleep = to_next_stage

View File

@@ -0,0 +1,8 @@
import logging
logger = logging.getLogger(__name__)
from .cog import ShoutoutCog
def prepare(bot):
bot.add_cog(ShoutoutCog(bot))

View File

@@ -0,0 +1,90 @@
import asyncio
from typing import Optional
import twitchio
from twitchio.ext import commands
from meta import CrocBot
from utils.lib import replace_multiple
from . import logger
from .data import ShoutoutData
class ShoutoutCog(commands.Cog):
# Future extension: channel defaults and config
DEFAULT_SHOUTOUT = """
We think that {name} is a great streamer and you should check them out \
and drop a follow! \
They {areorwere} streaming {game} at {channel}
"""
def __init__(self, bot: CrocBot):
self.bot = bot
self.data = bot.data.load_registry(ShoutoutData())
self.loaded = asyncio.Event()
async def cog_load(self):
await self.data.init()
self.loaded.set()
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
@commands.Cog.event('event_ready') # type: ignore
async def on_ready(self):
await self.ensure_loaded()
async def cog_check(self, ctx):
await self.ensure_loaded()
return True
async def format_shoutout(self, text: str, user: twitchio.User):
channels = await self.bot.fetch_channels([user.id])
if channels:
channel = channels[0]
game = channel.game_name or 'Unknown'
else:
game = 'Unknown'
streams = await self.bot.fetch_streams([user.id])
live = bool(streams)
mapping = {
'{name}': user.display_name,
'{channel}': f"https://www.twitch.tv/{user.name}",
'{game}': game,
'{areorwere}': 'are' if live else 'were',
}
return replace_multiple(text, mapping)
@commands.command(aliases=['so'])
async def shoutout(self, ctx: commands.Context, user: twitchio.User):
# Make sure caller is mod/broadcaster
# Lookup custom shoutout for this user
# If it exists use it, otherwise use default shoutout
if (ctx.author.is_mod or ctx.author.is_broadcaster):
data = await self.data.CustomShoutout.fetch(int(user.id))
if data:
shoutout = data.content
else:
shoutout = self.DEFAULT_SHOUTOUT
formatted = await self.format_shoutout(shoutout, user)
await ctx.reply(formatted)
# TODO: How to /shoutout with lib?
@commands.command()
async def editshoutout(self, ctx: commands.Context, user: twitchio.User, *, text: str):
# Make sure caller is mod/broadcaster/user themselves(?)
# upsert/delete and insert (is upsert impl?)
if (ctx.author.is_mod or ctx.author.is_broadcaster or int(ctx.author.id) == int(user.id)):
await self.data.CustomShoutout.table.delete_where(userid=int(user.id))
if text and text.lower() not in ('reset', 'none'):
await self.data.CustomShoutout.create(
userid=int(user.id),
content=text,
)
await ctx.reply("Custom shoutout updated!")
else:
await ctx.reply("Custom shoutout removed.")

View File

@@ -0,0 +1,21 @@
from data import Registry, RowModel
from data.columns import Integer, String, Timestamp
class ShoutoutData(Registry):
class CustomShoutout(RowModel):
"""
Schema
------
CREATE TABLE shoutouts(
userid BIGINT PRIMARY KEY,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_tablename_ = 'shoutouts'
_cache_ = {}
userid = Integer(primary=True)
content = String()
created_at = Timestamp()

View File

@@ -115,7 +115,7 @@ class AlertCog(LionCog):
# Note we set page size to 100 # Note we set page size to 100
# So we should never get repeat or missed streams # So we should never get repeat or missed streams
# Since we can request a max of 100 userids anyway. # Since we can request a max of 100 userids anyway.
streaming[stream.user_id] = stream streaming[int(stream.user_id)] = stream
started = set(streaming.keys()).difference(self.live_streams.keys()) started = set(streaming.keys()).difference(self.live_streams.keys())
ended = set(self.live_streams.keys()).difference(streaming.keys()) ended = set(self.live_streams.keys()).difference(streaming.keys())
@@ -123,9 +123,9 @@ class AlertCog(LionCog):
for streamerid in started: for streamerid in started:
stream = streaming[streamerid] stream = streaming[streamerid]
stream_data = await self.data.Stream.create( stream_data = await self.data.Stream.create(
streamerid=stream.user_id, streamerid=int(stream.user_id),
start_at=stream.started_at, start_at=stream.started_at,
twitch_stream_id=stream.id, twitch_stream_id=int(stream.id),
game_name=stream.game_name, game_name=stream.game_name,
title=stream.title, title=stream.title,
) )
@@ -143,7 +143,7 @@ class AlertCog(LionCog):
async def on_stream_start(self, stream_data): async def on_stream_start(self, stream_data):
# Get channel subscriptions listening for this streamer # Get channel subscriptions listening for this streamer
uid = stream_data.streamerid uid = int(stream_data.streamerid)
logger.info(f"Streamer <uid:{uid}> started streaming! {stream_data=}") logger.info(f"Streamer <uid:{uid}> started streaming! {stream_data=}")
subbed = await self.data.AlertChannel.fetch_where(streamerid=uid) subbed = await self.data.AlertChannel.fetch_where(streamerid=uid)
@@ -197,7 +197,7 @@ class AlertCog(LionCog):
return return
# Build message # Build message
streamer = await self.data.Streamer.fetch(stream_data.streamerid) streamer = await self.data.Streamer.fetch(int(stream_data.streamerid))
if not streamer: if not streamer:
# Streamer was deleted while handling the alert # Streamer was deleted while handling the alert
# Just quietly ignore # Just quietly ignore
@@ -235,7 +235,7 @@ class AlertCog(LionCog):
# Store sent alert # Store sent alert
alert = await self.data.StreamAlert.create( alert = await self.data.StreamAlert.create(
streamid=stream_data.streamid, streamid=int(stream_data.streamid),
subscriptionid=subscription.subscriptionid, subscriptionid=subscription.subscriptionid,
sent_at=utc_now(), sent_at=utc_now(),
messageid=message.id messageid=message.id
@@ -246,7 +246,7 @@ class AlertCog(LionCog):
async def on_stream_end(self, stream_data): async def on_stream_end(self, stream_data):
# Get channel subscriptions listening for this streamer # Get channel subscriptions listening for this streamer
uid = stream_data.streamerid uid = int(stream_data.streamerid)
logger.info(f"Streamer <uid:{uid}> stopped streaming! {stream_data=}") logger.info(f"Streamer <uid:{uid}> stopped streaming! {stream_data=}")
subbed = await self.data.AlertChannel.fetch_where(streamerid=uid) subbed = await self.data.AlertChannel.fetch_where(streamerid=uid)
@@ -269,8 +269,8 @@ class AlertCog(LionCog):
async def sub_resolve(self, subscription, stream_data): async def sub_resolve(self, subscription, stream_data):
# Check if there is a current active alert to resolve # Check if there is a current active alert to resolve
alerts = await self.data.StreamAlert.fetch_where( alerts = await self.data.StreamAlert.fetch_where(
streamid=stream_data.streamid, streamid=int(stream_data.streamid),
subscriptionid=subscription.subscriptionid, subscriptionid=int(subscription.subscriptionid),
) )
if not alerts: if not alerts:
logger.info( logger.info(
@@ -322,7 +322,7 @@ class AlertCog(LionCog):
) )
else: else:
# Edit message with custom arguments # Edit message with custom arguments
streamer = await self.data.Streamer.fetch(stream_data.streamerid) streamer = await self.data.Streamer.fetch(int(stream_data.streamerid))
formatter = await edit_setting.generate_formatter(self.bot, stream_data, streamer) formatter = await edit_setting.generate_formatter(self.bot, stream_data, streamer)
formatted = await formatter(edit_setting.value) formatted = await formatter(edit_setting.value)
args = edit_setting.value_to_args(subscription.subscriptionid, formatted) args = edit_setting.value_to_args(subscription.subscriptionid, formatted)
@@ -400,7 +400,7 @@ class AlertCog(LionCog):
# Create streamer data if it doesn't already exist # Create streamer data if it doesn't already exist
streamer_data = await self.data.Streamer.fetch_or_create( streamer_data = await self.data.Streamer.fetch_or_create(
tw_user.id, int(tw_user.id),
login_name=tw_user.login, login_name=tw_user.login,
display_name=tw_user.display_name, display_name=tw_user.display_name,
) )
@@ -418,8 +418,10 @@ class AlertCog(LionCog):
self.watching[streamer_data.userid] = streamer_data self.watching[streamer_data.userid] = streamer_data
# Open AlertEditorUI for the new subscription # Open AlertEditorUI for the new subscription
# TODO
await ctx.reply("StreamAlert Created.") await ctx.reply("StreamAlert Created.")
ui = AlertEditorUI(bot=self.bot, sub_data=sub_data, callerid=ctx.author.id)
await ui.run(ctx.interaction)
await ui.wait()
async def alert_acmpl(self, interaction: discord.Interaction, partial: str): async def alert_acmpl(self, interaction: discord.Interaction, partial: str):
if not interaction.guild: if not interaction.guild:

View File

@@ -0,0 +1,8 @@
import logging
logger = logging.getLogger(__name__)
from .cog import TagCog
def prepare(bot):
bot.add_cog(TagCog(bot))

View File

@@ -0,0 +1,152 @@
import asyncio
from collections import defaultdict
from typing import Optional
import difflib
import twitchio
from twitchio.ext import commands
from meta import CrocBot
from utils.lib import utc_now
from . import logger
from .data import TagData
class TagCog(commands.Cog):
def __init__(self, bot: CrocBot):
self.bot = bot
self.data = bot.data.load_registry(TagData())
self.loaded = asyncio.Event()
# Cache of channel tags, channelid -> name.lower() -> Tag
self.tags: dict[int, dict[str, TagData.Tag]] = {}
async def load_tags(self):
tags = defaultdict(dict)
rows = await self.data.Tag.fetch_where()
for row in rows:
tags[row.channelid][row.name.lower()] = row
self.tags.clear()
self.tags.update(tags)
async def cog_load(self):
await self.data.init()
await self.load_tags()
self.loaded.set()
async def ensure_loaded(self):
if not self.loaded.is_set():
await self.cog_load()
@commands.Cog.event('event_ready')
async def on_ready(self):
await self.ensure_loaded()
# API
async def create_tag(self, channelid: int, name: str, content: str, created_by: int):
"""
Create a new Tag with the given parameters.
If the tag already exists, will raise (TODO)
"""
row = await self.data.Tag.create(
channelid=channelid,
name=name,
content=content,
created_by=created_by,
)
if (chantags := self.tags.get(channelid, None)) is None:
chantags = self.tags[channelid] = {}
chantags[name.lower()] = row
logger.info(f"Created Tag: {row!r}")
return row
# Commands
@commands.command()
async def edittag(self, ctx: commands.Context, tagname: str, *, content: str):
"""
Create or edit a tag.
"""
channelid = int((await ctx.channel.user()).id)
userid = int(ctx.author.id)
# Fetch the tag if it exists
tag = self.tags.get(channelid, {}).get(tagname.lower(), None)
if tag is None:
# Create new tag
tag = await self.create_tag(
channelid,
tagname,
content,
userid
)
await ctx.reply(f"Tag '{tagname}' created as #{tag.tagid}!")
else:
# Edit existing tag
if not (ctx.author.is_mod or ctx.author.is_broadcaster or userid == tag.created_by):
await ctx.reply("You can't edit this tag!")
return
await tag.update(
content=content,
updated_at=utc_now()
)
await ctx.reply(f"Updated '{tag.name}'")
@commands.command()
async def deltag(self, ctx: commands.Context, tagname: str):
if ctx.author.is_broadcaster or ctx.author.is_mod:
channelid = int((await ctx.channel.user()).id)
tag = self.tags.get(channelid, {}).get(tagname.lower(), None)
if tag is None:
await ctx.reply(f"Couldn't find '{tagname}' to delete!")
else:
self.tags[channelid].pop(tag.name.lower())
await tag.delete()
await ctx.reply(f"Deleted '{tag.name}'")
@commands.command()
async def tag(self, ctx: commands.Context, tagname: str):
channelid = int((await ctx.channel.user()).id)
tags = self.tags.get(channelid, {})
if (tag := tags.get(tagname.lower(), None)) is None:
# Search for closest match
matches = difflib.get_close_matches(tagname.lower(), tags.keys(), n=2)
matchstr = "'{}'".format("' or '".join(matches)) if matches else None
suffix = f"Did you mean {matchstr}?" if matches else ""
await ctx.reply(f"Couldn't find tag '{tagname}'! {suffix}")
return
await ctx.reply(tag.content)
@commands.command(name='tags')
async def cmd_tags(self, ctx: commands.Context, *, searchstr: str = ''):
"""
List the tags available in the current channel.
"""
channelid = int((await ctx.channel.user()).id)
tag_names = [tag.name for tag in self.tags.get(channelid, {}).values()]
matching = [name for name in tag_names if searchstr.lower() in name.lower()]
tagstr = ', '.join(matching)
if searchstr:
if matching:
await ctx.reply(f"Matching tags: {tagstr}")
else:
await ctx.reply(f"No tags matching '{searchstr}'")
else:
if matching:
await ctx.reply(f"Available tags: {tagstr}")
else:
await ctx.reply("No tags set up on this channel!")

View File

@@ -0,0 +1,30 @@
from data import Registry, RowModel
from data.columns import Integer, String, Timestamp
class TagData(Registry):
class Tag(RowModel):
"""
Schema
------
CREATE TABLE channel_tags(
tagid SERIAL PRIMARY KEY,
channelid BIGINT NOT NULL,
name TEXT NOT NULL,
content TEXT NOT NULL,
created_by BIGINT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ
);
CREATE UNIQUE INDEX channel_tags_channelid_name ON channel_tags (channelid, name);
"""
_tablename_ = 'channel_tags'
_cache_ ={}
tagid = Integer(primary=True)
channelid = Integer()
name = String()
content = String()
created_by = Integer()
created_at = Timestamp()
updated_at = Timestamp()