diff --git a/data/migration/v12-13/migration.sql b/data/migration/v12-13/migration.sql index fcbf71ed..0b34f0d7 100644 --- a/data/migration/v12-13/migration.sql +++ b/data/migration/v12-13/migration.sql @@ -154,7 +154,10 @@ CREATE TYPE CoinTransactionType AS ENUM( 'VOICE_SESSION', 'TEXT_SESSION', 'ADMIN', - 'TASKS' + 'TASKS', + 'SCHEDULE_BOOK', + 'SCHEDULE_REWARD', + 'OTHER' ); @@ -795,6 +798,136 @@ CREATE TABLE channel_webhooks( -- }}} +-- Scheduled Sessions {{{ +/* Old Schema +CREATE TABLE accountability_slots( + slotid SERIAL PRIMARY KEY, + guildid BIGINT NOT NULL REFERENCES guild_config(guildid), + channelid BIGINT, + start_at TIMESTAMPTZ (0) NOT NULL, + messageid BIGINT, + closed_at TIMESTAMPTZ +); +CREATE UNIQUE INDEX slot_channels ON accountability_slots(channelid); +CREATE UNIQUE INDEX slot_guilds ON accountability_slots(guildid, start_at); +CREATE INDEX slot_times ON accountability_slots(start_at); + +CREATE TABLE accountability_members( + slotid INTEGER NOT NULL REFERENCES accountability_slots(slotid) ON DELETE CASCADE, + userid BIGINT NOT NULL, + paid INTEGER NOT NULL, + duration INTEGER DEFAULT 0, + last_joined_at TIMESTAMPTZ, + PRIMARY KEY (slotid, userid) +); +CREATE INDEX slot_members ON accountability_members(userid); +CREATE INDEX slot_members_slotid ON accountability_members(slotid); + +CREATE VIEW accountability_member_info AS + SELECT + * + FROM accountability_members + JOIN accountability_slots USING (slotid); + +CREATE VIEW accountability_open_slots AS + SELECT + * + FROM accountability_slots + WHERE closed_at IS NULL + ORDER BY start_at ASC; +*/ +-- Create new schema +CREATE TABLE schedule_slots( + slotid INTEGER PRIMARY KEY, + created_at TIMESTAMPTZ DEFAULT now() +); + +CREATE TABLE schedule_guild_config( + guildid BIGINT PRIMARY KEY REFERENCES guild_config ON DELETE CASCADE, + schedule_cost INTEGER, + reward INTEGER, + bonus_reward INTEGER, + min_attendance INTEGER, + lobby_channel BIGINT, + room_channel BIGINT, + blacklist_after INTEGER, + blacklistrole BIGINT, + created_at TIMESTAMPTZ DEFAULT now() +); + +CREATE TABLE schedule_channels( + guildid BIGINT NOT NULL REFERENCES schedule_guild_config ON DELETE CASCADE, + channelid BIGINT NOT NULL, + PRIMARY KEY (guildid, channelid) +); + +CREATE TABLE schedule_sessions( + guildid BIGINT NOT NULL REFERENCES schedule_guild_config ON DELETE CASCADE, + slotid INTEGER NOT NULL REFERENCES schedule_slots ON DELETE CASCADE, + opened_at TIMESTAMPTZ, + closed_at TIMESTAMPTZ, + messageid BIGINT, + created_at TIMESTAMPTZ DEFAULT now(), + PRIMARY KEY (guildid, slotid) +); + +CREATE TABLE schedule_session_members( + guildid BIGINT NOT NULL, + userid BIGINT NOT NULL, + slotid INTEGER NOT NULL, + booked_at TIMESTAMPTZ NOT NULL DEFAULT now(), + attended BOOLEAN NOT NULL DEFAULT False, + clock INTEGER NOT NULL DEFAULT 0, + book_transactionid INTEGER REFERENCES coin_transactions, + reward_transactionid INTEGER REFERENCES coin_transactions, + PRIMARY KEY (guildid, userid, slotid), + FOREIGN KEY (guildid, userid) REFERENCES members ON DELETE CASCADE, + FOREIGN KEY (guildid, slotid) REFERENCES schedule_sessions (guildid, slotid) ON DELETE CASCADE +); +CREATE INDEX schedule_session_members_users ON schedule_session_members(userid, slotid); + +-- Migrate data +--- Create schedule_slots from accountability_slots +INSERT INTO schedule_slots (slotid) + SELECT EXTRACT(EPOCH FROM old_slots.start_time) + FROM (SELECT DISTINCT(start_at) AS start_time FROM accountability_slots) AS old_slots; + +--- Create schedule_guild_config from guild_config +INSERT INTO schedule_guild_config (guildid, schedule_cost, reward, bonus_reward, lobby_channel) + SELECT guildid, accountability_price, accountability_reward, accountability_bonus, accountability_lobby + FROM guild_config + WHERE guildid IN (SELECT DISTINCT(guildid) FROM accountability_slots); + +--- Update session rooms from accountability_slots +WITH open_slots AS ( + SELECT guildid, MAX(channelid) AS channelid + FROM accountability_slots + WHERE closed_at IS NULL + GROUP BY guildid +) +UPDATE schedule_guild_config +SET room_channel = open_slots.channelid +FROM open_slots +WHERE schedule_guild_config.guildid = open_slots.guildid; + +--- Create schedule_sessions from accountability_slots +INSERT INTO schedule_sessions (guildid, slotid, opened_at, closed_at) + SELECT guildid, new_slots.slotid, start_at, closed_at + FROM accountability_slots old_slots + LEFT JOIN schedule_slots new_slots + ON EXTRACT(EPOCH FROM old_slots.start_at) = new_slots.slotid; + +--- Create schedule_session_members from accountability_members +INSERT INTO schedule_session_members (guildid, userid, slotid, booked_at, attended, clock) + SELECT old_slots.guildid, members.userid, new_slots.slotid, old_slots.start_at, (members.duration > 0), members.duration + FROM accountability_members members + LEFT JOIN accountability_slots old_slots ON members.slotid = old_slots.slotid + LEFT JOIN schedule_slots new_slots + ON EXTRACT(EPOCH FROM old_slots.start_at) = new_slots.slotid; + +-- Drop old schema +-- }}} + INSERT INTO VersionHistory (version, author) VALUES (13, 'v12-v13 migration'); COMMIT; diff --git a/data/migration/v12-13/schedule.sql b/data/migration/v12-13/schedule.sql new file mode 100644 index 00000000..4583f727 --- /dev/null +++ b/data/migration/v12-13/schedule.sql @@ -0,0 +1,97 @@ +DROP TABLE IF EXISTS schedule_slots CASCADE; +DROP TABLE IF EXISTS schedule_guild_config CASCADE; +DROP TABLE IF EXISTS schedule_channels CASCADE; +DROP TABLE IF EXISTS schedule_sessions CASCADE; +DROP TABLE IF EXISTS schedule_session_members CASCADE; + +-- Create new schema + +CREATE TABLE schedule_slots( + slotid INTEGER PRIMARY KEY, + created_at TIMESTAMPTZ DEFAULT now() +); + +CREATE TABLE schedule_guild_config( + guildid BIGINT PRIMARY KEY REFERENCES guild_config ON DELETE CASCADE, + schedule_cost INTEGER, + reward INTEGER, + bonus_reward INTEGER, + min_attendance INTEGER, + lobby_channel BIGINT, + room_channel BIGINT, + blacklist_after INTEGER, + blacklist_role BIGINT, + created_at TIMESTAMPTZ DEFAULT now() +); + +CREATE TABLE schedule_channels( + guildid BIGINT NOT NULL REFERENCES schedule_guild_config ON DELETE CASCADE, + channelid BIGINT NOT NULL, + PRIMARY KEY (guildid, channelid) +); + +CREATE TABLE schedule_sessions( + guildid BIGINT NOT NULL REFERENCES schedule_guild_config ON DELETE CASCADE, + slotid INTEGER NOT NULL REFERENCES schedule_slots ON DELETE CASCADE, + opened_at TIMESTAMPTZ, + closed_at TIMESTAMPTZ, + messageid BIGINT, + created_at TIMESTAMPTZ DEFAULT now(), + PRIMARY KEY (guildid, slotid) +); + +CREATE TABLE schedule_session_members( + guildid BIGINT NOT NULL, + userid BIGINT NOT NULL, + slotid INTEGER NOT NULL, + booked_at TIMESTAMPTZ NOT NULL DEFAULT now(), + attended BOOLEAN NOT NULL DEFAULT False, + clock INTEGER NOT NULL DEFAULT 0, + book_transactionid INTEGER REFERENCES coin_transactions, + reward_transactionid INTEGER REFERENCES coin_transactions, + PRIMARY KEY (guildid, userid, slotid), + FOREIGN KEY (guildid, userid) REFERENCES members ON DELETE CASCADE, + FOREIGN KEY (guildid, slotid) REFERENCES schedule_sessions (guildid, slotid) ON DELETE CASCADE +); +CREATE INDEX schedule_session_members_users ON schedule_session_members(userid, slotid); + +-- Migrate data +--- Create schedule_slots from accountability_slots +INSERT INTO schedule_slots (slotid) + SELECT EXTRACT(EPOCH FROM old_slots.start_time) + FROM (SELECT DISTINCT(start_at) AS start_time FROM accountability_slots) AS old_slots; + +--- Create schedule_guild_config from guild_config +INSERT INTO schedule_guild_config (guildid, schedule_cost, reward, bonus_reward, lobby_channel) + SELECT guildid, accountability_price, accountability_reward, accountability_bonus, accountability_lobby + FROM guild_config + WHERE guildid IN (SELECT DISTINCT(guildid) FROM accountability_slots); + +--- Update session rooms from accountability_slots +WITH open_slots AS ( + SELECT guildid, MAX(channelid) AS channelid + FROM accountability_slots + WHERE closed_at IS NULL + GROUP BY guildid +) +UPDATE schedule_guild_config +SET room_channel = open_slots.channelid +FROM open_slots +WHERE schedule_guild_config.guildid = open_slots.guildid; + +--- Create schedule_sessions from accountability_slots +INSERT INTO schedule_sessions (guildid, slotid, opened_at, closed_at) + SELECT guildid, new_slots.slotid, start_at, closed_at + FROM accountability_slots old_slots + LEFT JOIN schedule_slots new_slots + ON EXTRACT(EPOCH FROM old_slots.start_at) = new_slots.slotid; + +--- Create schedule_session_members from accountability_members +INSERT INTO schedule_session_members (guildid, userid, slotid, booked_at, attended, clock) + SELECT old_slots.guildid, members.userid, new_slots.slotid, old_slots.start_at, (members.duration > 0), members.duration + FROM accountability_members members + LEFT JOIN accountability_slots old_slots ON members.slotid = old_slots.slotid + LEFT JOIN schedule_slots new_slots + ON EXTRACT(EPOCH FROM old_slots.start_at) = new_slots.slotid; + +-- Drop old schema diff --git a/scripts/start_leo_debug.py b/scripts/start_leo_debug.py index ce58e1c9..5fa7809a 100755 --- a/scripts/start_leo_debug.py +++ b/scripts/start_leo_debug.py @@ -18,14 +18,14 @@ def loop_exception_handler(loop, context): print(context) task: asyncio.Task = context.get('task', None) if task is not None: - addendum = f"" + addendum = f"" message = context.get('message', '') context['message'] = ' '.join((message, addendum)) loop.default_exception_handler(context) event_loop.set_exception_handler(loop_exception_handler) -event_loop.set_debug(enabled=True) +# event_loop.set_debug(enabled=True) if __name__ == '__main__': diff --git a/src/core/__init__.py b/src/core/__init__.py index 0a7fa849..64672d49 100644 --- a/src/core/__init__.py +++ b/src/core/__init__.py @@ -1,6 +1,11 @@ from .cog import CoreCog from .config import ConfigCog +from babel.translator import LocalBabel + + +babel = LocalBabel('lion-core') + async def setup(bot): await bot.add_cog(CoreCog(bot)) diff --git a/src/core/lion.py b/src/core/lion.py index 8cf28740..a78a8fb0 100644 --- a/src/core/lion.py +++ b/src/core/lion.py @@ -1,9 +1,11 @@ from typing import Optional from cachetools import LRUCache +import itertools import datetime import discord from meta import LionCog, LionBot, LionContext +from utils.data import MEMBERS from data import WeakCache from .data import CoreData @@ -99,7 +101,7 @@ class Lions(LionCog): ('guildid',), *((guildid,) for guildid in missing) ).with_adapter(self.data.Guild._make_rows) - rows = (*rows, *new_rows) + rows = itertools.chain(rows, new_rows) for row in rows: guildid = row.guildid @@ -107,6 +109,35 @@ class Lions(LionCog): return guild_map + async def fetch_users(self, *userids) -> dict[int, LionUser]: + """ + Fetch (or create) multiple LionUsers simultaneously, using cache where possible. + """ + user_map = {} + missing = set() + for userid in userids: + luser = self.lion_users.get(userid, None) + user_map[userid] = luser + if luser is None: + missing.add(userid) + + if missing: + rows = await self.data.User.fetch_where(userid=list(missing)) + missing.difference_update(row.userid for row in rows) + + if missing: + new_rows = await self.data.User.table.insert_many( + ('userid',), + *((userid,) for userid in missing) + ).with_adapter(self.data.user._make_rows) + rows = itertools.chain(rows, new_rows) + + for row in rows: + userid = row.userid + self.lion_users[userid] = user_map[userid] = LionUser(self.bot, row) + + return user_map + async def fetch_member(self, guildid, userid, member: Optional[discord.Member] = None) -> LionMember: """ Fetch the given LionMember, using cache for data if possible. @@ -124,11 +155,46 @@ class Lions(LionCog): self.lion_members[key] = lmember return lmember - async def fetch_members(self, *members: tuple[int, int]): + async def fetch_members(self, *memberids: tuple[int, int]) -> dict[tuple[int, int], LionMember]: """ Fetch or create multiple members simultaneously. """ - # TODO: Actually batch this (URGENT) - members = {} - for key in members: - members[key] = await self.fetch_member(*key) + member_map = {} + missing = set() + + # Retrieve what we can from cache + for memberid in memberids: + lmember = self.lion_members.get(memberid, None) + member_map[memberid] = lmember + if lmember is None: + missing.add(memberid) + + # Fetch or create members that weren't in cache + if missing: + # First fetch or create the guilds and users + lguilds = await self.fetch_guilds(*(gid for gid, _ in missing)) + lusers = await self.fetch_users(*(uid for _, uid in missing)) + + # Now attempt to load members from data + rows = await self.data.Member.fetch_where(MEMBERS(*missing)) + missing.difference_update((row.guildid, row.userid) for row in rows) + + # Create any member rows that are still missing + if missing: + new_rows = await self.data.Member.table.insert_many( + ('guildid', 'userid'), + *missing + ).with_adapter(self.data.Member._make_rows) + rows = itertools.chain(rows, new_rows) + + # We have all the data, now construct the member objects + for row in rows: + key = (row.guildid, row.userid) + self.lion_members[key] = member_map[key] = LionMember( + self.bot, + row, + lguilds[row.guildid], + lusers[row.userid] + ) + + return member_map diff --git a/src/core/setting_types.py b/src/core/setting_types.py new file mode 100644 index 00000000..5471b490 --- /dev/null +++ b/src/core/setting_types.py @@ -0,0 +1,64 @@ +""" +Additional abstract setting types useful for StudyLion settings. +""" +from settings.setting_types import IntegerSetting +from meta import conf +from meta.errors import UserInputError +from constants import MAX_COINS +from babel.translator import ctx_translator + +from . import babel + +_p = babel._p + + +class CoinSetting(IntegerSetting): + """ + Setting type mixin describing a LionCoin setting. + """ + _min = 0 + _max = MAX_COINS + + _accepts = _p('settype:coin|accepts', "A positive integral number of coins.") + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + """ + Parse the user input into an integer. + """ + if not string: + return None + try: + num = int(string) + except Exception: + t = ctx_translator.get().t + + raise UserInputError(t(_p( + 'settype:coin|parse|error:notinteger', + "The coin quantity must be a positive integer!" + ))) from None + + if num > cls._max: + t = ctx_translator.get().t + raise UserInputError(t(_p( + 'settype:coin|parse|error:too_large', + "Provided number of coins was too high!" + ))) from None + elif num < cls._min: + t = ctx_translator.get().t + raise UserInputError(t(_p( + 'settype:coin|parse|error:too_large', + "Provided number of coins was too low!" + ))) from None + + return num + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + if data is not None: + t = ctx_translator.get().t + formatted = t(_p( + 'settype:coin|formatted', + "{coin}**{amount}**" + )).format(coin=conf.emojis.coin, amount=data) + return formatted diff --git a/src/data/models.py b/src/data/models.py index 36ef7ff4..54b62824 100644 --- a/src/data/models.py +++ b/src/data/models.py @@ -137,7 +137,7 @@ class RowModel: _registry: Optional[Registry] = None # TODO: Proper typing for a classvariable which gets dynamically assigned in subclass - table: RowTable + table: RowTable = None def __init_subclass__(cls: Type[RowT], table: Optional[str] = None): """ diff --git a/src/data/queries.py b/src/data/queries.py index aafd7c2e..a64d0950 100644 --- a/src/data/queries.py +++ b/src/data/queries.py @@ -121,7 +121,7 @@ class TableQuery(Query[QueryResult]): """ __slots__ = ( 'tableid', - 'condition', '_extra', '_limit', '_order', '_joins' + 'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group' ) def __init__(self, tableid, *args, **kwargs): @@ -282,6 +282,26 @@ class LimitMixin(TableQuery[QueryResult]): return None +class FromMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._from: Optional[Expression] = None + + def from_expr(self, _from: Expression): + self._from = _from + return self + + @property + def _from_section(self) -> Optional[Expression]: + if self._from is not None: + expr, values = self._from.as_tuple() + return RawExpr(sql.SQL("FROM {}").format(expr), values) + else: + return None + + class ORDER(Enum): ASC = sql.SQL('ASC') DESC = sql.SQL('DESC') @@ -331,6 +351,36 @@ class OrderMixin(TableQuery[QueryResult]): return None +class GroupMixin(TableQuery[QueryResult]): + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._group: list[Expression] = [] + + def group_by(self, *exprs: Union[Expression, str]): + """ + Add a group expression(s) to the query. + This method stacks. + """ + for expr in exprs: + if isinstance(expr, Expression): + self._group.append(expr) + else: + self._group.append(RawExpr(sql.Identifier(expr))) + return self + + @property + def _group_section(self) -> Optional[Expression]: + if self._group: + expr = RawExpr.join(*self._group, joiner=sql.SQL(', ')) + expr.expr = sql.SQL("GROUP BY {}").format(expr.expr) + return expr + else: + return None + + class Insert(ExtraMixin, TableQuery[QueryResult]): """ Query type representing a table insert query. @@ -411,7 +461,7 @@ class Insert(ExtraMixin, TableQuery[QueryResult]): return RawExpr.join(*sections) -class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, TableQuery[QueryResult]): +class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]): """ Select rows from a table matching provided conditions. """ @@ -464,6 +514,7 @@ class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, TableQue RawExpr(base, columns_values), self._join_section, self._where_section, + self._group_section, self._extra_section, self._order_section, self._limit_section, @@ -495,7 +546,7 @@ class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]): return RawExpr.join(*sections) -class Update(LimitMixin, WhereMixin, ExtraMixin, TableQuery[QueryResult]): +class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]): __slots__ = ( '_set', ) @@ -534,6 +585,7 @@ class Update(LimitMixin, WhereMixin, ExtraMixin, TableQuery[QueryResult]): ) sections = [ RawExpr(base, set_values), + self._from_section, self._where_section, self._extra_section, self._limit_section, diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py index 67a3c1c1..cfdc6a80 100644 --- a/src/meta/LionBot.py +++ b/src/meta/LionBot.py @@ -172,7 +172,7 @@ class LionBot(Bot): "An unexpected error occurred while processing your command!\n" "Our development team has been notified, and the issue should be fixed soon.\n" "If the error persists, please contact our support team and give them the following number: " - f"`{ctx.interaction.id}`" + f"`{ctx.interaction.id if ctx.interaction else ctx.message.id}`" ) try: diff --git a/src/modules/__init__.py b/src/modules/__init__.py index d0db8273..ba855f3b 100644 --- a/src/modules/__init__.py +++ b/src/modules/__init__.py @@ -4,6 +4,7 @@ active = [ '.sysadmin', '.config', '.user_config', + '.schedule', '.economy', '.ranks', '.reminders', diff --git a/src/modules/economy/data.py b/src/modules/economy/data.py index ac9b39f1..bb8ef9e1 100644 --- a/src/modules/economy/data.py +++ b/src/modules/economy/data.py @@ -3,8 +3,8 @@ from enum import Enum from psycopg import sql from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr from data.columns import Integer, Bool, Column, Timestamp - from core.data import CoreData +from utils.data import TemporaryTable, SAFECOINS class TransactionType(Enum): @@ -12,20 +12,28 @@ class TransactionType(Enum): Schema ------ CREATE TYPE CoinTransactionType AS ENUM( - 'REFUND', - 'TRANSFER', - 'SHOP_PURCHASE', - 'STUDY_SESSION', - 'ADMIN', - 'TASKS' + 'REFUND', + 'TRANSFER', + 'SHOP_PURCHASE', + 'VOICE_SESSION', + 'TEXT_SESSION', + 'ADMIN', + 'TASKS', + 'SCHEDULE_BOOK', + 'SCHEDULE_REWARD', + 'OTHER' ); """ REFUND = 'REFUND', TRANSFER = 'TRANSFER', - PURCHASE = 'SHOP_PURCHASE', - SESSION = 'STUDY_SESSION', + SHOP_PURCHASE = 'SHOP_PURCHASE', + VOICE_SESSION = 'VOICE_SESSION', + TEXT_SESSION = 'TEXT_SESSION', ADMIN = 'ADMIN', TASKS = 'TASKS', + SCHEDULE_BOOK = 'SCHEDULE_BOOK', + SCHEDULE_REWARD = 'SCHEDULE_REWARD', + OTHER = 'OTHER', class AdminActionTarget(Enum): @@ -100,21 +108,99 @@ class EconomyData(Registry, name='economy'): from_account: int, to_account: int, amount: int, bonus: int = 0, refunds: int = None ): - transaction = await cls.create( - transactiontype=transaction_type, - guildid=guildid, actorid=actorid, amount=amount, bonus=bonus, - from_account=from_account, to_account=to_account, - refunds=refunds - ) - if from_account is not None: - await CoreData.Member.table.update_where( - guildid=guildid, userid=from_account - ).set(coins=(CoreData.Member.coins - (amount + bonus))) - if to_account is not None: - await CoreData.Member.table.update_where( - guildid=guildid, userid=to_account - ).set(coins=(CoreData.Member.coins + (amount + bonus))) - return transaction + conn = await cls._connector.get_connection() + async with conn.transaction(): + transaction = await cls.create( + transactiontype=transaction_type, + guildid=guildid, actorid=actorid, amount=amount, bonus=bonus, + from_account=from_account, to_account=to_account, + refunds=refunds + ) + if from_account is not None: + await CoreData.Member.table.update_where( + guildid=guildid, userid=from_account + ).set(coins=SAFECOINS(CoreData.Member.coins - (amount + bonus))) + if to_account is not None: + await CoreData.Member.table.update_where( + guildid=guildid, userid=to_account + ).set(coins=SAFECOINS(CoreData.Member.coins + (amount + bonus))) + return transaction + + @classmethod + async def execute_transactions(cls, *transactions): + """ + Execute multiple transactions in one data transaction. + + Writes the transaction and updates the affected member accounts. + Returns the created Transactions. + + Arguments + --------- + transactions: tuple[TransactionType, int, int, int, int, int, int, int] + (transaction_type, guildid, actorid, from_account, to_account, amount, bonus, refunds) + """ + if not transactions: + return [] + + conn = await cls._connector.get_connection() + async with conn.transaction(): + # Create the transactions + rows = await cls.table.insert_many( + ( + 'transactiontype', + 'guildid', 'actorid', + 'from_account', 'to_account', + 'amount', 'bonus', + 'refunds' + ), + *transactions + ).with_adapter(cls._make_rows) + + # Update the members + transtable = TemporaryTable( + '_guildid', '_userid', '_amount', + types=('BIGINT', 'BIGINT', 'INTEGER') + ) + values = transtable.values + for transaction in transactions: + _, guildid, _, from_acc, to_acc, amount, bonus, _ = transaction + coins = amount + bonus + if coins: + if from_acc: + values.append((guildid, from_acc, -1 * coins)) + if to_acc: + values.append((guildid, to_acc, coins)) + if values: + Member = CoreData.Member + await Member.table.update_where( + guildid=transtable['_guildid'], userid=transtable['_userid'] + ).set( + coins=SAFECOINS(Member.coins + transtable['_amount']) + ).from_expr(transtable) + return rows + + @classmethod + async def refund_transactions(cls, *transactionids, actorid=0): + if not transactionids: + return [] + conn = await cls._connector.get_connection() + async with conn.transaction(): + # First fetch the transaction rows to refund + data = await cls.table.select_where(transactionid=transactionids) + if data: + # Build the transaction refund data + records = [ + ( + TransactionType.REFUND, + tr['guildid'], actorid, + tr['to_account'], tr['from_account'], + tr['amount'] + tr['bonus'], 0, + tr['transactionid'] + ) + for tr in data + ] + # Execute refund transactions + return await cls.execute_transactions(*records) class ShopTransaction(RowModel): """ diff --git a/src/modules/pomodoro/timer.py b/src/modules/pomodoro/timer.py index 2e4fd9fd..ad94eb84 100644 --- a/src/modules/pomodoro/timer.py +++ b/src/modules/pomodoro/timer.py @@ -141,6 +141,7 @@ class Timer: hook = self._hook = await self.bot.core.data.LionHook.fetch(cid) if not hook: # Attempt to create and save webhook + # TODO: Localise try: if channel.permissions_for(channel.guild.me).manage_webhooks: avatar = self.bot.user.avatar diff --git a/src/modules/schedule/__init__.py b/src/modules/schedule/__init__.py new file mode 100644 index 00000000..f601e910 --- /dev/null +++ b/src/modules/schedule/__init__.py @@ -0,0 +1,10 @@ +import logging +from babel.translator import LocalBabel + +logger = logging.getLogger(__name__) +babel = LocalBabel('schedule') + + +async def setup(bot): + from .cog import ScheduleCog + await bot.add_cog(ScheduleCog(bot)) diff --git a/src/modules/schedule/cog.py b/src/modules/schedule/cog.py new file mode 100644 index 00000000..6cf93ae5 --- /dev/null +++ b/src/modules/schedule/cog.py @@ -0,0 +1,882 @@ +from typing import Optional +from weakref import WeakValueDictionary +import datetime as dt +from collections import defaultdict +import asyncio + +import discord +from discord.ext import commands as cmds +from discord import app_commands as appcmds +from discord.app_commands import Range + +from meta import LionCog, LionBot, LionContext +from meta.logger import log_wrap +from meta.errors import UserInputError, ResponseTimedOut +from meta.sharding import THIS_SHARD +from utils.lib import utc_now, error_embed +from utils.ui import Confirm +from utils.data import MULTIVALUE_IN, MEMBERS +from wards import low_management_ward +from core.data import CoreData +from data import NULL, ORDER +from modules.economy.data import TransactionType +from constants import MAX_COINS + +from . import babel, logger +from .data import ScheduleData +from .settings import ScheduleSettings, ScheduleConfig +from .ui.scheduleui import ScheduleUI +from .ui.settingui import ScheduleSettingUI +from .core import TimeSlot, ScheduledSession, SessionMember +from .lib import slotid_to_utc, time_to_slotid + +_p, _np = babel._p, babel._np + + +class ScheduleCog(LionCog): + def __init__(self, bot: LionBot): + self.bot = bot + self.data: ScheduleData = bot.db.load_registry(ScheduleData()) + self.settings = ScheduleSettings() + + # Whether we are ready to take events + self.initialised = asyncio.Event() + + # Activated slot cache + self.active_slots: dict[int, TimeSlot] = {} # slotid -> TimeSlot + + # External modification (including spawing) a slot requires holding a slot lock + self._slot_locks = WeakValueDictionary() + + # Modifying a non-running slot or session requires holding the spawn lock + # This ensures the slot will not start while being modified + self.spawn_lock = asyncio.Lock() + + # Spawner loop task + self.spawn_task: Optional[asyncio.Task] = None + + self.session_channels = self.settings.SessionChannels._cache + + async def cog_load(self): + await self.data.init() + + # Update the session channel cache + await self.settings.SessionChannels.setup(self.bot) + + configcog = self.bot.get_cog('ConfigCog') + self.crossload_group(self.configure_group, configcog.configure_group) + + if self.bot.is_ready(): + await self.initialise() + + async def cog_unload(self): + """ + Cancel session spawning and the ongoing sessions. + """ + # TODO: Test/design for reload + if self.spawn_task and not self.spawn_task.done(): + self.spawn_task.cancel() + + for slot in list(self.active_slots.values()): + if slot.run_task and not slot.run_task.done(): + slot.run_task.cancel() + for session in slot.sessions.values(): + if session._updater and not session._updater.done(): + session._update.cancel() + if session._status_task and not session._status_task.done(): + session._status_task.cancel() + + @LionCog.listener('on_ready') + @log_wrap(action='Init Schedule') + async def initialise(self): + """ + Launch current timeslots, cleanup missed timeslots, and start the spawner. + """ + # Wait until voice session tracker has initialised + tracker = self.bot.get_cog('VoiceTrackerCog') + await tracker.initialised.wait() + + # Spawn the current session + now = utc_now() + nowid = time_to_slotid(now) + await self._spawner(nowid) + + # Start the spawner, with a small jitter based on shard id (for db loading) + spawn_start = now.replace(minute=30, second=0, microsecond=0) + spawn_start += dt.timedelta(seconds=self.bot.shard_id * 10) + self.spawn_task = asyncio.create_task(self._spawn_loop(start_at=spawn_start)) + + # Cleanup after missed or delayed timeslots + model = self.data.ScheduleSession + missed_session_data = await model.fetch_where( + model.slotid < nowid, + model.slotid > (nowid - 24 * 60 * 60), + model.closed_at == NULL, + THIS_SHARD + ) + if missed_session_data: + # Partition by slotid + slotid_session_data = defaultdict(list) + for row in missed_session_data: + slotid_session_data[row.slotid].append(row) + + # Fetch associated TimeSlots, oldest first + slot_data = await self.data.ScheduleSlot.fetch_where( + slotid=list(slotid_session_data.keys()) + ).order_by('slotid') + + # Process each slot + for row in slot_data: + try: + slot = TimeSlot(self, row) + sessions = await slot.load_sessions(slotid_session_data[slot.slotid]) + await slot.cleanup(list(sessions.values())) + except Exception: + logger.exception( + f"Unhandled exception while cleaning up missed timeslot {row!r}" + ) + self.initialised.set() + + @log_wrap(stack=['Schedule Spawner']) + async def _spawn_loop(self, start_at: dt.datetime): + """ + Every hour, starting at start_at, + the spawn loop will use `_spawner` to ensure the next slotid has been launched. + """ + next_spawn = start_at + while True: + try: + await discord.utils.sleep_until(next_spawn) + except asyncio.CancelledError: + break + next_spawn = next_spawn + dt.timedelta(hours=1) + try: + nextid = time_to_slotid(next_spawn) + await self._spawner(nextid) + except asyncio.CancelledError: + break + except Exception: + logger.exception( + "Unexpected error occurred while spawning scheduled sessions." + ) + + @log_wrap(action='Spawn') + async def _spawner(self, slotid): + """ + Ensure the provided slotid exists and is running. + """ + async with self.slotlock(slotid): + slot = self.active_slots.get(slotid, None) + if slot is None or slot.run_task is None: + slot_data = await self.data.ScheduleSlot.fetch_or_create(slotid) + slot = TimeSlot(self, slot_data) + await slot.fetch() + self.active_slots[slotid] = slot + self._launch(slot) + logger.info(f"Spawned Schedule TimeSlot ") + + def _launch(self, slot: TimeSlot): + launch_task = slot.launch() + key = slot.slotid + launch_task.add_done_callback(lambda fut: self.active_slots.pop(key, None)) + + # API + def slotlock(self, slotid): + lock = self._slot_locks.get(slotid, None) + if lock is None: + lock = self._slot_locks[slotid] = asyncio.Lock() + return lock + + @log_wrap(action='Cancel Booking') + async def cancel_bookings(self, *bookingids: tuple[int, int, int], refund=True): + """ + Cancel the provided bookings. + + bookingid: tuple[int, int, int] + Tuple of (slotid, guildid, userid) + """ + slotids = set(bookingid[0] for bookingid in bookingids) + locks = [self.slotlock(slotid) for slotid in slotids] + + # Request all relevant slotlocks + await asyncio.gather(*(lock.acquire() for lock in locks)) + try: + # TODO: Some benchmarking here + # Should we do the channel updates in bulk? + for bookingid in bookingids: + await self._cancel_booking_active(*bookingid) + + # Now delete from data + records = await self.data.ScheduleSessionMember.table.delete_where( + MULTIVALUE_IN( + ('slotid', 'guildid', 'userid'), + *bookingids + ) + ) + + # Refund cancelled bookings + if refund: + maybe_tids = (record['book_transactionid'] for record in records) + tids = [tid for tid in maybe_tids if tid is not None] + if tids: + economy = self.bot.get_cog('Economy') + await economy.data.Transaction.refund_transactions(*tids) + finally: + for lock in locks: + lock.release() + return records + + async def _cancel_booking_active(self, slotid, guildid, userid): + """ + Booking cancel worker for active slots. + + Does nothing if the provided bookingid is not active. + The slot lock MUST be taken before this is run. + """ + if not self.slotlock(slotid).locked(): + raise ValueError("Attempting to cancel active booking without taking slotlock.") + + slot = self.active_slots.get(slotid, None) + session = slot.sessions.get(guildid, None) if slot else None + member = session.pop(userid, None) if session else None + if member is not None: + if slot.closing.is_set(): + # Don't try to cancel a booking for a closing active slot. + return + async with session.lock: + # Update message if it has already been sent + session.update_message_soon(resend=False) + room = session.room_channel + member = session.guild.get_member(userid) if room else None + if room and member and session.prepared: + # Update channel permissions unless the member is in the next session and it is prepared + nextslotid = slotid + 3600 + nextslot = self.active_slots.get(nextslotid, None) + nextsession = nextslot.sessions.get(guildid, None) if nextslot else None + nextmember = (userid in nextsession.members) if nextsession else None + + unlock = None + try: + if nextmember: + unlock = nextsession.lock + await unlock.acquire() + update = (not nextsession.prepared) + else: + update = True + if update: + await room.set_permissions(member, overwrite=None) + except discord.HTTPException: + pass + finally: + if unlock is not None: + unlock.release() + elif slot is not None and member is None: + # Should not happen + logger.error( + f"Cancelling booking " + "for active slot " + "but the session member was not found. This should not happen." + ) + + @log_wrap(action='Clear Member Schedule') + async def clear_member_schedule(self, guildid, userid, refund=False): + """ + Cancel all current and future bookings for the given member. + """ + now = utc_now() + nowid = time_to_slotid(now) + + # First retrieve current and future booking data + bookings = await self.data.ScheduleSessionMember.fetch_where( + (ScheduleData.ScheduleSessionMember.slotid >= nowid), + guildid=guildid, + userid=userid, + ) + bookingids = [(b.slotid, guildid, userid) for b in bookings] + if bookingids: + await self.cancel_bookings(*bookingids, refund=refund) + + @log_wrap(action='Handle NoShow') + async def handle_noshow(self, *memberids): + """ + Handle "did not show" members. + + Typically cancels all future sessions for this member, + blacklists depending on guild settings, + and notifies the user. + """ + now = utc_now() + nowid = time_to_slotid(now) + member_model = self.data.ScheduleSessionMember + + # First handle blacklist + guildids, userids = map(set, zip(*memberids)) + # This should hit cache + config_data = await self.data.ScheduleGuild.fetch_multiple(*guildids) + autoblacklisting = {} + for gid, row in config_data.items(): + if row['blacklist_after'] and (rid := row['blacklist_role']): + guild = self.bot.get_guild(gid) + role = guild.get_role(rid) if guild else None + if role is not None: + autoblacklisting[gid] = (row['blacklist_after'], role) + + to_blacklist = {} + if autoblacklisting: + # Count number of missed sessions in the last 24h for each member in memberids + # who is also in an autoblacklisting guild + members = {} + for gid, uid in memberids: + if gid in autoblacklisting: + guild = self.bot.get_guild(gid) + member = guild.get_member(uid) if guild else None + if member: + members[(gid, uid)] = member + + if members: + missed = await member_model.table.select_where( + member_model.slotid < nowid, + member_model.slotid >= nowid - 24 * 3600, + MEMBERS(*members.keys()), + attended=False, + ).select( + guildid=member_model.guildid, + userid=member_model.userid, + missed="COUNT(slotid)" + ).group_by(member_model.guildid, member_model.userid).with_no_adapter() + for row in missed: + if row['missed'] >= autoblacklisting[row['guildid']][0]: + key = (row['guildid'], row['userid']) + to_blacklist[key] = members[key] + + if to_blacklist: + # Actually apply blacklist + tasks = [] + for (gid, uid), member in to_blacklist.items(): + role = autoblacklisting[gid][1] + task = asyncio.create_task(member.add_role(role)) + tasks.append(task) + # TODO: Logging and some error handling + await asyncio.gather(*tasks, return_exceptions=True) + + # Now cancel future sessions for members who were not blacklisted and are not currently clocked on + to_clear = [] + activeslot = self.active_slots[nowid] + for mid in memberids: + if mid not in to_blacklist: + gid, uid = mid + session = activeslot.sessions.get(gid, None) + member = session.members.get(uid, None) if session else None + clocked = (member is not None) and (member.clock_start is not None) + if not clocked: + to_clear.append(mid) + + if to_clear: + # Retrieve booking data + bookings = await member_model.fetch_where( + (member_model.slotid >= nowid), + MEMBERS(*to_clear) + ) + bookingids = [(b.slotid, b.guildid, b.userid) for b in bookings] + if bookingids: + await self.cancel_bookings(*bookingids, refund=False) + # TODO: Logging and error handling + + @log_wrap(action='Create Booking') + async def create_booking(self, guildid, userid, *slotids): + """ + Create new bookings with the given bookingids. + + Probably best refactored into an interactive method, + with some parts in slot and session. + """ + t = self.bot.translator.t + locks = [self.slotlock(slotid) for slotid in slotids] + await asyncio.gather(*(lock.acquire() for lock in locks)) + try: + conn = await self.bot.db.get_connection() + async with conn.transaction(): + # Validate bookings + guild_data = await self.data.ScheduleGuild.fetch_or_create(guildid) + config = ScheduleConfig(guildid, guild_data) + + # Check guild lobby exists + if config.get(ScheduleSettings.SessionLobby.setting_id).value is None: + error = t(_p( + 'create_booking|error:no_lobby', + "This server has not set a `session_lobby`, so the scheduled session system is disabled!" + )) + raise UserInputError(error) + + # Fetch up to data lion data and member data + lion = await self.bot.core.lions.fetch_member(guildid, userid) + member = await lion.fetch_member() + await lion.data.refresh() + if not member: + # This should pretty much never happen unless something went wrong on Discord's end + error = t(_p( + 'create_booking|error:no_member', + "An unknown Discord error occurred. Please try again in a few minutes." + )) + raise UserInputError(error) + + # Check member blacklist + if (role := config.get(ScheduleSettings.BlacklistRole.setting_id).value) and role in member.roles: + error = t(_p( + 'create_booking|error:blacklisted', + "You have been blacklisted from the scheduled session system in this server." + )) + raise UserInputError(error) + + # Check member balance + requested = len(slotids) + required = len(slotids) * config.get(ScheduleSettings.ScheduleCost.setting_id).value + balance = lion.data.coins + if balance < required: + error = t(_np( + 'create_booking|error:insufficient_balance', + "Booking a session costs {coin}**{required}**, but you only have {coin}**{balance}**.", + "Booking `{count}` sessions costs {coin}**{required}**, but you only have {coin}**{balance}**.", + requested + )).format( + count=requested, coin=self.bot.config.emojis.coin, + required=required, balance=balance + ) + raise UserInputError(error) + + # Check existing bookings + schedule = await self._fetch_schedule(userid) + if set(slotids).intersection(schedule.keys()): + error = t(_p( + 'create_booking|error:already_booked', + "One or more requested timeslots are already booked!" + )) + raise UserInputError(error) + + # Booking request is now validated. Perform bookings. + + # Fetch or create session data + await self.data.ScheduleSlot.fetch_multiple(*slotids) + session_data = await self.data.ScheduleSession.fetch_multiple( + *((guildid, slotid) for slotid in slotids) + ) + + # Create transactions + economy = self.bot.get_cog('Economy') + trans_data = ( + TransactionType.SCHEDULE_BOOK, + guildid, userid, userid, 0, + config.get(ScheduleSettings.ScheduleCost.setting_id).value, + 0, None + ) + transactions = await economy.data.Transaction.execute_transactions(*(trans_data for _ in slotids)) + transactionids = [row.transactionid for row in transactions] + + # Create bookings + now = utc_now() + booking_data = await self.data.ScheduleSessionMember.table.insert_many( + ('guildid', 'userid', 'slotid', 'booked_at', 'book_transactionid'), + *( + (guildid, userid, slotid, now, tid) + for slotid, tid in zip(slotids, transactionids) + ) + ) + + # Now pass to activated slots + for record in booking_data: + slotid = record['slotid'] + if (slot := self.active_slots.get(slotid, None)): + session = slot.sessions.get(guildid, None) + if session is None: + # Create a new session in the slot and set it up + session = await slot.load_sessions(session_data[guildid, slotid]) + slot.sessions[guildid] = session + if slot.closing.is_set(): + # This should never happen + logger.error( + "Attempt to book a session in a closing slot. This should be impossible." + ) + raise ValueError('Cannot book a session in a closing slot.') + elif slot.opening.is_set(): + await slot.open([session]) + elif slot.preparing.is_set(): + await slot.prepare([session]) + else: + # Session already exists in the slot + async with session.lock: + if session.prepared: + session.update_status_soon() + if (room := session.room_channel) and (mem := session.guild.get_member(userid)): + try: + await room.set_permissions( + mem, connect=True, view_channel=True + ) + except discord.HTTPException: + pass + finally: + for lock in locks: + lock.release() + # TODO: Logging and error handling + return booking_data + + # Event listeners + @LionCog.listener('on_member_update') + @log_wrap(action="Schedule Check Blacklist") + async def check_blacklist_role(self, before: discord.Member, after: discord.Member): + guild = before.guild + await self.initialised.wait() + before_roles = {role.id for role in before.roles} + new_roles = {role.id for role in after.roles if role.id not in before_roles} + if new_roles: + # This should be in cache in the vast majority of cases + guild_data = await self.data.ScheduleGuild.fetch(guild.id) + if (roleid := guild_data.blacklist_role) is not None and roleid in new_roles: + # Clear member schedule + await self.clear_member_schedule(guild.id, after.id) + + @LionCog.listener('on_member_remove') + @log_wrap(action="Schedule Member Remove") + async def clear_leaving_member(self, member: discord.Member): + """ + When a member leaves, clear their schedule + """ + await self.initialised.wait() + await self.clear_member_schedule(member.guild.id, member.id, refund=True) + + @LionCog.listener('on_guild_remove') + @log_wrap(action="Schedule Guild Remove") + async def clear_leaving_guild(self, guild: discord.Guild): + """ + When leaving a guild, delete all future bookings in the guild. + + This avoids penalising members for missing sessions in guilds we are not part of. + However, do not delete the guild sessions, + this allows seamless resuming if we rejoin the guild (aside from the cancelled sessions). + + Note that loaded sessions are independent of whether we are in the guild or not + (rather, we load all sessions that match this shard). + Hence we do not need to recreate the sessions when we join a new guild. + """ + await self.initialised.wait() + + now = utc_now() + nowid = time_to_slotid(now) + + bookings = await self.data.ScheduleSessionMember.fetch_where( + (ScheduleData.ScheduleSessionMember.slotid >= nowid), + guildid=guild.id + ) + bookingids = [(b.slotid, b.guildid, b.userid) for b in bookings] + if bookingids: + await self.cancel_bookings(*bookingids, refund=True) + + @LionCog.listener('on_voice_session_start') + @log_wrap(action="Schedule Clock On") + async def schedule_clockon(self, session_data): + try: + # DEBUG + logger.debug(f"Handling clock on parsing for {session_data}") + # Get current slot + now = utc_now() + nowid = time_to_slotid(now) + async with self.slotlock(nowid): + slot = self.active_slots.get(nowid, None) + if slot is not None: + # Get session in current slot + session = slot.sessions.get(session_data.guildid, None) + member = session.members.get(session_data.userid, None) if session else None + if member is not None: + async with session.lock: + if session.listening and session.validate_channel(session_data.channelid): + member.clock_on(session_data.start_time) + session.update_status_soon() + logger.debug( + f"Clocked on member {member.data!r} with session {session_data!r}" + ) + except Exception: + logger.exception( + f"Unexpected exception while clocking on voice sessions {session_data!r}" + ) + + @LionCog.listener('on_voice_session_end') + @log_wrap(action="Schedule Clock Off") + async def schedule_clockoff(self, session_data, ended_at): + try: + # DEBUG + logger.debug(f"Handling clock off parsing for {session_data}") + # Get current slot + now = utc_now() + nowid = time_to_slotid(now) + async with self.slotlock(nowid): + slot = self.active_slots.get(nowid, None) + if slot is not None: + # Get session in current slot + session = slot.sessions.get(session_data.guildid) + member = session.members.get(session_data.userid, None) if session else None + if member is not None: + async with session.lock: + if session.listening and session.validate_channel(session_data.channelid): + member.clock_off(ended_at) + session.update_status_soon() + logger.debug( + f"Clocked off member {member.data!r} from session {session_data!r}" + ) + except Exception: + logger.exception( + f"Unexpected exception while clocking off voice sessions {session_data!r}" + ) + + # Schedule commands + @cmds.hybrid_command( + name=_p('cmd:schedule', "schedule"), + description=_p( + 'cmd:schedule|desc', + "View and manage your scheduled session." + ) + ) + @appcmds.guild_only + async def schedule_cmd(self, ctx: LionContext): + # TODO: Auotocomplete for book and cancel options + # Will require TTL caching for member schedules. + book = None + cancel = None + if not ctx.guild: + return + if not ctx.interaction: + return + + t = self.bot.translator.t + guildid = ctx.guild.id + guild_data = await self.data.ScheduleGuild.fetch_or_create(guildid) + config = ScheduleConfig(guildid, guild_data) + now = utc_now() + lines: list[tuple[bool, str]] = [] # (error_status, msg) + + if cancel is not None: + schedule = await self._fetch_schedule(ctx.author.id) + # Validate provided + if not cancel.isdigit(): + # Error, slot {cancel} not recognised, please select a session to cancel from the acmpl list. + error = t(_p( + 'cmd:schedule|cancel_booking|error:parse_slot', + "Time slot `{provided}` not recognised. " + "Please select a session to cancel from the autocomplete options." + )) + line = (True, error) + elif (slotid := int(cancel)) not in schedule: + # Can't cancel slot because it isn't booked + error = t(_p( + 'cmd:schedule|cancel_booking|error:not_booked', + "Could not cancel {time} booking because it is not booked!" + )).format( + time=discord.utils.format_dt(slotid_to_utc(slotid), style='t') + ) + line = (True, error) + elif (slotid_to_utc(slotid) - now).total_seconds() < 60: + # Can't cancel slot because it is running or about to start + error = t(_p( + 'cmd:schedule|cancel_booking|error:too_soon', + "Cannot cancel {time} booking because it is running or starting soon!" + )).format( + time=discord.utils.format_dt(slotid_to_utc(slotid), style='t') + ) + line = (True, error) + else: + # Okay, slot is booked and cancellable. + # Actually cancel it + booking = schedule[slotid] + await self.cancel_bookings((booking.slotid, booking.guildid, booking.userid)) + # Confirm cancel done + ack = t(_p( + 'cmd:schedule|cancel_booking|success', + "Successfully cancelled your booking at {time}." + )).format( + time=discord.utils.format_dt(slotid_to_utc(slotid), style='t') + ) + line = (False, ack) + lines.append(line) + + if book is not None: + schedule = await self._fetch_schedule(ctx.author.id) + if not book.isdigit(): + # Error, slot not recognised, please use autocomplete menu + error = t(_p( + 'cmd:schedule|create_booking|error:parse_slot', + "Time slot `{provided}` not recognised. " + "Please select a session to cancel from the autocomplete options." + )) + lines = (True, error) + elif (slotid := int(book)) in schedule: + # Can't book because the slot is already booked + error = t(_p( + 'cmd:schedule|create_booking|error:already_booked', + "You have already booked a scheduled session for {time}." + )).format( + time=discord.utils.format_dt(slotid_to_utc(slotid), style='t') + ) + lines = (True, error) + elif (slotid_to_utc(slotid) - now).total_seconds() < 60: + # Can't book because it is running or about to start + error = t(_p( + 'cmd:schedule|create_booking|error:too_soon', + "Cannot book session at {time} because it is running or starting soon!" + )).format( + time=discord.utils.format_dt(slotid_to_utc(slotid), style='t') + ) + line = (True, error) + else: + # The slotid is valid and bookable + # Run the booking + try: + await self.create_booking(guildid, ctx.author.id) + ack = t(_p( + 'cmd:schedule|create_booking|success', + "You have successfully scheduled a session at {time}." + )).format( + time=discord.utils.format_dt(slotid_to_utc(slotid), style='t') + ) + line = (False, ack) + except UserInputError as e: + line = (True, e.msg) + lines.append(line) + + if lines: + # Post lines + any_failed = False + text = [] + + for failed, msg in lines: + any_failed = any_failed or failed + emoji = self.bot.config.emojis.warning if failed else self.bot.config.emojis.tick + text.append(f"{emoji} {msg}") + + embed = discord.Embed( + colour=discord.Colour.brand_red() if any_failed else discord.Colour.brand_green(), + description='\n'.join(text) + ) + await ctx.interaction.edit_original_response(embed=embed) + else: + # Post ScheduleUI + ui = ScheduleUI(self.bot, ctx.guild, ctx.author.id) + await ui.run(ctx.interaction) + await ui.wait() + + async def _fetch_schedule(self, userid, **kwargs): + """ + Fetch the given user's schedule (i.e. booking map) + """ + nowid = time_to_slotid(utc_now()) + + booking_model = self.data.ScheduleSessionMember + bookings = await booking_model.fetch_where( + booking_model.slotid >= nowid, + userid=userid, + ).order_by('slotid', ORDER.ASC) + + return { + booking.slotid: booking for booking in bookings + } + + # Configuration + @LionCog.placeholder_group + @cmds.hybrid_group('configre', with_app_command=False) + async def configure_group(self, ctx: LionContext): + """ + Substitute configure command group. + """ + pass + + config_params = { + 'session_lobby': ScheduleSettings.SessionLobby, + 'session_room': ScheduleSettings.SessionRoom, + 'schedule_cost': ScheduleSettings.ScheduleCost, + 'attendance_reward': ScheduleSettings.AttendanceReward, + 'attendance_bonus': ScheduleSettings.AttendanceBonus, + 'min_attendance': ScheduleSettings.MinAttendance, + 'blacklist_role': ScheduleSettings.BlacklistRole, + 'blacklist_after': ScheduleSettings.BlacklistAfter, + } + + @configure_group.command( + name=_p('cmd:configure_schedule', "schedule"), + description=_p( + 'cmd:configure_schedule|desc', + "Configure Scheduled Session system" + ) + ) + @appcmds.rename( + **{param: option._display_name for param, option in config_params.items()} + ) + @appcmds.describe( + **{param: option._desc for param, option in config_params.items()} + ) + @low_management_ward + async def configure_schedule_command(self, ctx: LionContext, + session_lobby: Optional[discord.TextChannel | discord.VoiceChannel] = None, + session_room: Optional[discord.VoiceChannel] = None, + schedule_cost: Optional[appcmds.Range[int, 0, MAX_COINS]] = None, + attendance_reward: Optional[appcmds.Range[int, 0, MAX_COINS]] = None, + attendance_bonus: Optional[appcmds.Range[int, 0, MAX_COINS]] = None, + min_attendance: Optional[appcmds.Range[int, 1, 60]] = None, + blacklist_role: Optional[discord.Role] = None, + blacklist_after: Optional[appcmds.Range[int, 1, 24]] = None + ): + # Type Guards + if not ctx.guild: + return + if not ctx.interaction: + return + + # Map of parameter names to setting values + provided = { + 'session_lobby': session_lobby, + 'session_room': session_room, + 'schedule_cost': schedule_cost, + 'attendance_reward': attendance_reward, + 'attendance_bonus': attendance_bonus, + 'min_attendance': min_attendance, + 'blacklist_role': blacklist_role, + 'blacklist_after': blacklist_after, + } + modified = set(param for param, value in provided.items() if value is not None) + + # Make a config instance + guild_data = await self.data.ScheduleGuild.fetch_or_create(ctx.guild.id) + config = ScheduleConfig(ctx.guild.id, guild_data) + + if modified: + # Check provided values and build a list of write arguments + # Note that all settings are ModelSettings of ScheduleData.ScheduleGuild + lines = [] + update_args = {} + settings = [] + for param in modified: + # TODO: Add checks with setting._check_value + setting = self.config_params[param] + new_value = provided[param] + + instance = config.get(setting.setting_id) + instance.value = new_value + settings.append(instance) + update_args[instance._column] = instance._data + lines.append(instance.update_message) + + # Perform data update + await guild_data.update(**update_args) + # Dispatch setting updates to trigger hooks + for setting in settings: + setting.dispatch_update() + + # Ack modified settings + tick = self.bot.config.emojis.tick + embed = discord.Embed( + colour=discord.Colour.brand_green(), + description='\n'.join(f"{tick} {line}" for line in lines) + ) + await ctx.reply(embed=embed) + + # Launch config UI if needed + if ctx.channel.id not in ScheduleSettingUI._listening or not modified: + ui = ScheduleSettingUI(self.bot, ctx.guild.id, ctx.channel.id) + await ui.run(ctx.interaction) + await ui.wait() diff --git a/src/modules/schedule/core/__init__.py b/src/modules/schedule/core/__init__.py new file mode 100644 index 00000000..1179a9cd --- /dev/null +++ b/src/modules/schedule/core/__init__.py @@ -0,0 +1,3 @@ +from .session_member import SessionMember +from .session import ScheduledSession +from .timeslot import TimeSlot diff --git a/src/modules/schedule/core/session.py b/src/modules/schedule/core/session.py new file mode 100644 index 00000000..95fb0688 --- /dev/null +++ b/src/modules/schedule/core/session.py @@ -0,0 +1,544 @@ +from typing import Optional +import datetime as dt +import asyncio + +import discord + +from meta import LionBot +from utils.lib import utc_now +from utils.lib import MessageArgs + +from .. import babel, logger +from ..data import ScheduleData as Data +from ..lib import slotid_to_utc +from ..settings import ScheduleSettings as Settings +from ..settings import ScheduleConfig +from ..ui.sessionui import SessionUI + +from .session_member import SessionMember + +_p = babel._p + +my_room_permissions = discord.Permissions( + connect=True, + view_channel=True, + manage_roles=True, + manage_permissions=True +) +member_room_permissions = discord.PermissionOverwrite( + connect=True, + view_channel=True +) + + +class ScheduledSession: + """ + Guild-local context for a scheduled session timeslot. + + Manages the status message and member list. + """ + update_interval = 60 + max_update_interval = 10 + + # TODO: Slots + # NOTE: All methods MUST permit the guild or channels randomly vanishing + # NOTE: All methods MUST be robust, and not propagate exceptions + # TODO: Guild locale context + def __init__(self, + bot: LionBot, + data: Data.ScheduleSession, config_data: Data.ScheduleGuild, + session_channels: Settings.SessionChannels): + self.bot = bot + self.data = data + self.slotid = data.slotid + self.guildid = data.guildid + self.config = ScheduleConfig(self.guildid, config_data) + self.channels_setting = session_channels + + self.starts_at = slotid_to_utc(self.slotid) + self.ends_at = slotid_to_utc(self.slotid + 3600) + + # Whether to listen to clock events + # should be set externally after the clocks have been initially set + self.listening = False + # Whether the session has prepared the room and sent the first message + # Also set by open() + self.prepared = False + # Whether the session has set the room permissions + self.opened = False + # Whether this session has been cancelled. Always set externally + self.cancelled = False + + self.members: dict[int, SessionMember] = {} + self.lock = asyncio.Lock() + + self.status_message = None + self._hook = None # Lobby webhook data + self._warned_hook = False + + self._last_update = None + self._updater = None + self._status_task = None + + # Setting shortcuts + @property + def room_channel(self) -> Optional[discord.VoiceChannel]: + return self.config.get(Settings.SessionRoom.setting_id).value + + @property + def lobby_channel(self) -> Optional[discord.TextChannel]: + return self.config.get(Settings.SessionLobby.setting_id).value + + @property + def bonus_reward(self) -> int: + return self.config.get(Settings.AttendanceBonus.setting_id).value + + @property + def attended_reward(self) -> int: + return self.config.get(Settings.AttendanceReward.setting_id).value + + @property + def min_attendence(self) -> int: + return self.config.get(Settings.MinAttendance.setting_id).value + + @property + def all_attended(self) -> bool: + return all(member.total_clock >= self.min_attendence for member in self.members.values()) + + @property + def can_run(self) -> bool: + """ + Returns True if this session exists and needs to run. + """ + return self.guild and self.members + + @property + def messageid(self) -> Optional[int]: + return self.status_message.id if self.status_message else None + + @property + def guild(self) -> Optional[discord.Guild]: + return self.bot.get_guild(self.guildid) + + def validate_channel(self, channelid) -> bool: + channel = self.bot.get_channel(channelid) + if channel is not None: + channels = self.channels_setting.value + return (not channels) or (channel in channels) or (channel.category and (channel.category in channels)) + else: + return False + + async def get_lobby_hook(self) -> Optional[discord.Webhook]: + """ + Fetch or create the webhook in the scheduled session lobby + """ + channel = self.lobby_channel + if channel: + cid = channel.id + if self._hook and self._hook.channelid == cid: + hook = self._hook + else: + hook = self._hook = await self.bot.core.data.LionHook.fetch(cid) + if not hook: + # Attempt to create + try: + if channel.permissions_for(channel.guild.me).manage_webhooks: + avatar = self.bot.user.avatar + avatar_data = (await avatar.to_file()).fp.read() if avatar else None + webhook = await channel.create_webhook( + avatar=avatar_data, + name=f"{self.bot.user.name} Scheduled Sessions", + reason="Scheduled Session Lobby" + ) + hook = await self.bot.core.data.LionHook.create( + channelid=cid, + token=webhook.token, + webhookid=webhook.id + ) + elif channel.permissions_for(channel.guild.me).send_messages and not self._warned_hook: + t = self.bot.translator.t + self._warned_hook = True + await channel.send( + t(_p( + 'session|error:lobby_webhook_perms', + "Insufficient permissions to create a webhook in this channel. " + "I require the `MANAGE_WEBHOOKS` permission." + )) + ) + except discord.HTTPException: + logger.warning( + "Unexpected Exception occurred while creating scheduled session lobby webhook.", + exc_info=True + ) + if hook: + return hook.as_webhook(client=self.bot) + + async def send(self, *args, wait=True, **kwargs): + lobby_hook = await self.get_lobby_hook() + if lobby_hook: + try: + return await lobby_hook.send(*args, wait=wait, **kwargs) + except discord.NotFound: + # Webhook was deleted under us + if self._hook is not None: + await self._hook.delete() + self._hook = None + except discord.HTTPException: + logger.warning( + f"Exception occurred sending to webhooks for scheduled session {self.data!r}", + exc_info=True + ) + + async def prepare(self, **kwargs): + """ + Execute prepare stage for this guild. + """ + async with self.lock: + await self.prepare_room() + await self.update_status(**kwargs) + self.prepared = True + + async def prepare_room(self): + """ + Add overwrites allowing current members to connect. + """ + async with self.lock: + if not (members := list(self.members.values())): + return + if not (guild := self.guild): + return + if not (room := self.room_channel): + return + + if room.permissions_for(guild.me) >= my_room_permissions: + # Add member overwrites + overwrites = room.overwrites + for member in members: + mobj = guild.get_member(member.userid) + if mobj: + overwrites[mobj] = discord.PermissionOverwrite(connect=True, view_channel=True) + try: + await room.edit(overwrites=overwrites) + except discord.HTTPException: + logger.warning( + f"Unexpected discord exception received while preparing schedule session room " + f"in guild for timeslot .", + exc_info=True + ) + else: + logger.debug( + f"Prepared schedule session room " + f"in guild for timeslot .", + ) + else: + t = self.bot.translator.t + await self.send( + t(_p( + 'session|prepare|error:room_permissions', + f"Could not prepare the configured session room {room} for the next scheduled session! " + "I require the `MANAGE_CHANNEL`, `MANAGE_ROLES`, `CONNECT` and `VIEW_CHANNEL` permissions." + )).format(room=room.mention) + ) + + async def open_room(self): + """ + Remove overwrites for non-members. + """ + async with self.lock: + if not (members := list(self.members.values())): + return + if not (guild := self.guild): + return + if not (room := self.room_channel): + return + + if room.permissions_for(guild.me) >= my_room_permissions: + # Replace the member overwrites + overwrites = { + target: overwrite for target, overwrite in room.overwrites.items() + if not isinstance(target, discord.Member) + } + for member in members: + mobj = guild.get_member(member.userid) + if mobj: + overwrites[mobj] = discord.PermissionOverwrite(connect=True, view_channel=True) + try: + await room.edit(overwrites=overwrites) + except discord.HTTPException: + logger.exception( + f"Unhandled discord exception received while opening schedule session room " + f"in guild for timeslot ." + ) + else: + logger.debug( + f"Opened schedule session room " + f"in guild for timeslot .", + ) + else: + t = self.bot.translator.t + await self.send( + t(_p( + 'session|open|error:room_permissions', + f"Could not set up the configured session room {room} for this scheduled session! " + "I require the `MANAGE_CHANNEL`, `MANAGE_ROLES`, `CONNECT` and `VIEW_CHANNEL` permissions." + )).format(room=room.mention) + ) + self.prepared = True + self.opened = True + + async def notify(self): + """ + Ghost ping members who have not yet attended. + """ + missing = [mid for mid, m in self.members.items() if m.total_clock == 0 and m.clock_start is None] + if missing: + ping = ''.join(f"<@{mid}>" for mid in missing) + message = await self.send(ping) + if message is not None: + asyncio.create_task(message.delete()) + + async def current_status(self) -> MessageArgs: + """ + Lobby status message args. + """ + t = self.bot.translator.t + now = utc_now() + + view = SessionUI(self.bot, self.slotid, self.guildid) + embed = discord.Embed( + colour=discord.Colour.orange(), + title=t(_p( + 'session|status|title', + "Session {start} - {end}" + )).format( + start=discord.utils.format_dt(self.starts_at, 't'), + end=discord.utils.format_dt(self.ends_at, 't'), + ) + ) + embed.timestamp = now + + if self.cancelled: + embed.description = t(_p( + 'session|status|desc:cancelled', + "I cancelled this scheduled session because I was unavailable. " + "All members who booked the session have been refunded." + )) + view = None + elif not self.members: + embed.description = t(_p( + 'session|status|desc:no_members', + "*No members scheduled this session.*" + )) + elif now < self.starts_at: + # Preparation stage + embed.description = t(_p( + 'session|status:preparing|desc:has_members', + "Starting {start}" + )).format(start=discord.utils.format_dt(self.starts_at, 'R')) + embed.add_field( + name=t(_p('session|status:preparing|field:members', "Members")), + value=', '.join(f"<@{m}>" for m in self.members) + ) + elif now < self.starts_at + dt.timedelta(hours=1): + # Running status + embed.description = t(_p( + 'session|status:running|desc:has_members', + "Finishing {start}" + )).format(start=discord.utils.format_dt(self.ends_at, 'R')) + + missing = [] + present = [] + min_attendence = self.min_attendence + for mid, member in self.members.items(): + clock = int(member.total_clock) + if clock == 0 and member.clock_start is None: + memstr = f"<@{mid}>" + missing.append(memstr) + else: + memstr = "<@{mid}> **({M:02}:{S:02})**".format( + mid=mid, + M=int(clock // 60), + S=int(clock % 60) + ) + present.append((memstr, clock, bool(member.clock_start))) + + waiting_for = [] + attending = [] + attended = [] + present.sort(key=lambda t: t[1], reverse=True) + for memstr, clock, clocking in present: + if clocking: + attending.append(memstr) + elif clock >= min_attendence: + attended.append(memstr) + else: + waiting_for.append(memstr) + waiting_for.extend(missing) + + if waiting_for: + embed.add_field( + name=t(_p('session|status:running|field:waiting', "Waiting For")), + value='\n'.join(waiting_for), + inline=True + ) + if attending: + embed.add_field( + name=t(_p('session|status:running|field:attending', "Attending")), + value='\n'.join(attending), + inline=True + ) + if attended: + embed.add_field( + name=t(_p('session|status:running|field:attended', "Attended")), + value='\n'.join(attended), + inline=True + ) + else: + # Finished, show summary + attended = [] + missed = [] + min_attendence = self.min_attendence + for mid, member in self.members.items(): + clock = int(member.total_clock) + memstr = "<@{mid}> **({M:02}:{S:02})**".format( + mid=mid, + M=int(clock // 60), + S=int(clock % 60) + ) + if clock < min_attendence: + missed.append(memstr) + else: + attended.append(memstr) + + if not missed: + # Everyone attended + embed.description = t(_p( + 'session|status:finished|desc:everyone_att', + "Everyone attended the session! " + "All members were rewarded with {coin} **{reward} + {bonus}**!" + )).format( + coin=self.bot.config.emojis.coin, + reward=self.attended_reward, + bonus=self.bonus_reward + ) + elif missed and attended: + # Mix of both + embed.description = t(_p( + 'session|status:finished|desc:some_att', + "Everyone who attended was rewarded with {coin} **{reward}**! " + "Some members did not attend so everyone missed out on the bonus {coin} **{bonus}**.\n" + "**Members who missed their session have all future sessions cancelled without refund!*" + )).format( + coin=self.bot.config.emojis.coin, + reward=self.attended_reward, + bonus=self.bonus_reward + ) + else: + # No-one attended + embed.description = t(_p( + 'session|status:finished|desc:some_att', + "No-one attended this session! No-one received rewards.\n" + "**Members who missed their session have all future sessions cancelled without refund!*" + )) + + if attended: + embed.add_field( + name=t(_p('session|status:finished|field:attended', "Attended")), + value='\n'.join(attended) + ) + if missed: + embed.add_field( + name=t(_p('session|status:finished|field:missing', "Missing")), + value='\n'.join(missed) + ) + view = None + + if view is not None: + await view.reload() + args = MessageArgs(embed=embed, view=view) + return args + + async def _update_status(self, save=True, resend=True): + """ + Send or update the lobby message. + """ + self._last_update = utc_now() + args = await self.current_status() + + message = self.status_message + if message is None and self.data.messageid is not None: + lobby_hook = await self.get_lobby_hook() + if lobby_hook: + try: + message = await lobby_hook.fetch_message(self.data.messageid) + except discord.HTTPException: + message = None + + repost = message is None + if not repost: + try: + await message.edit(**args.edit_args) + self.status_message = message + except discord.NotFound: + repost = True + self.status_message = None + except discord.HTTPException: + # Unexpected issue updating the message + logger.exception( + f"Exception occurred updating status for scheduled session {self.data!r}" + ) + + if repost and resend and self.members: + message = await self.send(**args.send_args) + self.status_message = message + if save: + await self.data.update(messageid=message.id if message else None) + + async def _update_status_soon(self, **kwargs): + try: + if self._last_update is not None: + next_update = self._last_update + dt.timedelta(seconds=self.max_update_interval) + await discord.utils.sleep_until(next_update) + task = asyncio.create_task(self._update_status(**kwargs)) + await asyncio.shield(task) + except asyncio.CancelledError: + pass + + def update_status_soon(self, **kwargs): + if self._status_task and not self._status_task.done(): + self._status_task.cancel() + self._status_task = asyncio.create_task(self._update_status_soon(**kwargs)) + + async def update_status(self, **kwargs): + if self._status_task and not self._status_task.done(): + self._status_task.cancel() + await self._update_status(**kwargs) + + async def update_loop(self): + """ + Keep the lobby message up to date with a message per minute. + Takes into account external and manual updates. + """ + try: + if self._last_update: + await discord.utils.sleep_until(self._last_update + dt.timedelta(seconds=self.update_interval)) + + while (now := utc_now()) <= self.ends_at: + await self.update_status() + while now < (next_update := (self._last_update + dt.timedelta(seconds=self.update_interval))): + await discord.utils.sleep_until(next_update) + now = utc_now() + await self.update_status() + except asyncio.CancelledError: + logger.debug( + f"Cancelled scheduled session update loop ,gid: {self.guildid}>" + ) + except Exception: + logger.exception( + "Unknown exception encountered during session " + f"update loop ,gid: {self.guildid}>" + ) + + def start_updating(self): + self._updater = asyncio.create_task(self.update_loop()) + return self._updater diff --git a/src/modules/schedule/core/session_member.py b/src/modules/schedule/core/session_member.py new file mode 100644 index 00000000..ddf594e6 --- /dev/null +++ b/src/modules/schedule/core/session_member.py @@ -0,0 +1,69 @@ +from typing import Optional +from collections import defaultdict +import datetime as dt +import asyncio +import itertools + +import discord + +from meta import LionBot +from utils.lib import utc_now +from core.lion_member import LionMember + +from .. import babel, logger +from ..data import ScheduleData as Data +from ..lib import slotid_to_utc + +_p = babel._p + + +class SessionMember: + """ + Member context for a scheduled session timeslot. + + Intended to keep track of members for ongoing and upcoming sessions. + Primarily used to track clock time and set attended status. + """ + # TODO: slots + + def __init__(self, + bot: LionBot, data: Data.ScheduleSessionMember, + lion: LionMember): + self.bot = bot + self.data = data + self.lion = lion + + self.slotid = data.slotid + self.slot_start = slotid_to_utc(self.slotid) + self.slot_end = slotid_to_utc(self.slotid + 3600) + self.userid = data.userid + self.guildid = data.guildid + + self.clock_start = None + self.clocked = 0 + + @property + def total_clock(self): + clocked = self.clocked + if self.clock_start is not None: + end = min(utc_now(), self.slot_end) + clocked += (end - self.clock_start).total_seconds() + return clocked + + def clock_on(self, at: dt.datetime): + """ + Mark this member as attending the scheduled session. + """ + if self.clock_start: + self.clock_off(at) + self.clock_start = max(self.slot_start, at) + + def clock_off(self, at: dt.datetime): + """ + Mark this member as no longer attending. + """ + if not self.clock_start: + raise ValueError("Member clocking off while already off.") + end = min(at, self.slot_end) + self.clocked += (end - self.clock_start).total_seconds() + self.clock_start = None diff --git a/src/modules/schedule/core/timeslot.py b/src/modules/schedule/core/timeslot.py new file mode 100644 index 00000000..a16640d8 --- /dev/null +++ b/src/modules/schedule/core/timeslot.py @@ -0,0 +1,529 @@ +from typing import Optional, TYPE_CHECKING +from collections import defaultdict +import datetime as dt +import asyncio + +import discord + +from meta import LionBot +from meta.sharding import THIS_SHARD +from meta.logger import log_context, log_wrap +from utils.lib import utc_now +from core.lion_member import LionMember +from core.lion_guild import LionGuild +from tracking.voice.session import SessionState +from utils.data import as_duration, MEMBERS, TemporaryTable +from modules.economy.cog import Economy +from modules.economy.data import EconomyData, TransactionType + +from .. import babel, logger +from ..data import ScheduleData as Data +from ..lib import slotid_to_utc, batchrun_per_second +from ..settings import ScheduleSettings + +from .session import ScheduledSession +from .session_member import SessionMember + +if TYPE_CHECKING: + from ..cog import ScheduleCog + +_p = babel._p + + +class TimeSlot: + """ + Represents a single schedule session timeslot. + + Maintains a cache of ScheduleSessions for event handling. + Responsible for the state of all scheduled sessions in this timeslot. + Provides methods for executing each stage of the time slot, + performing operations concurrently where possible. + """ + # TODO: Logging context + # TODO: Add per-shard jitter to improve ratelimit handling + + def __init__(self, cog: 'ScheduleCog', slot_data: Data.ScheduleSlot): + self.cog = cog + self.bot: LionBot = cog.bot + self.data: Data = cog.data + self.slot_data = slot_data + self.slotid = slot_data.slotid + log_context.set(f"slotid: {self.slotid}") + + self.prep_at = slotid_to_utc(self.slotid - 15*60) + self.start_at = slotid_to_utc(self.slotid) + self.end_at = slotid_to_utc(self.slotid + 3600) + + self.preparing = asyncio.Event() + self.opening = asyncio.Event() + self.opened = asyncio.Event() + self.closing = asyncio.Event() + + self.sessions: dict[int, ScheduledSession] = {} # guildid -> loaded ScheduledSession + self.run_task = None + self.loaded = False + + @log_wrap(action="Fetch sessions") + async def fetch(self): + """ + Load all slot sessions from data. Must be executed before reading event based updates. + + Does not take session lock because nothing external should read or modify before load. + """ + self.loaded = False + self.sessions.clear() + session_data = await self.data.ScheduleSession.fetch_where( + THIS_SHARD, + slotid=self.slotid, + closed_at=None, + ) + sessions = await self.load_sessions(session_data) + self.sessions.update(sessions) + self.loaded = True + logger.info( + f"Timeslot finished preloading {len(self.sessions)} guilds. Ready to open." + ) + + @log_wrap(action="Load sessions") + async def load_sessions(self, session_data) -> dict[int, ScheduledSession]: + """ + Load slot state for the provided GuildSchedule rows. + """ + if not session_data: + return {} + + guildids = [row.guildid for row in session_data] + + # Bulk fetch guild config data + config_data = await self.data.ScheduleGuild.fetch_multiple(*guildids) + + # Fetch channel data. This *should* hit cache if initialisation did its job + channel_settings = {guildid: await ScheduleSettings.SessionChannels.get(guildid) for guildid in guildids} + + # Data fetch all member schedules with this slotid + members = await self.data.ScheduleSessionMember.fetch_where( + slotid=self.slotid, + guildid=guildids + ) + # Bulk fetch lions + lions = await self.bot.core.lions.fetch_members( + *((m.guildid, m.userid) for m in members) + ) if members else {} + + # Partition member data + session_member_data = defaultdict(list) + for mem in members: + session_member_data[mem.guildid].append(mem) + + # Create the session guilds and session members. + sessions = {} + for row in session_data: + session = ScheduledSession(self.bot, row, config_data[row.guildid], channel_settings[row.guildid]) + smembers = {} + for memdata in session_member_data[row.guildid]: + smember = SessionMember( + self.bot, memdata, lions[memdata.guildid, memdata.userid] + ) + smembers[memdata.userid] = smember + session.members = smembers + sessions[row.guildid] = session + + logger.debug( + f"Timeslot " + f"loaded guild data for {len(sessions)} guilds: {', '.join(map(str, guildids))}" + ) + return sessions + + @log_wrap(action="Reset Clocks") + async def _reset_clocks(self, sessions: list[ScheduledSession]): + """ + Accurately set clocks (i.e. attendance time) for all tracked members in this time slot. + """ + now = utc_now() + tracker = self.bot.get_cog('VoiceTrackerCog') + tracking_lock = tracker.tracking_lock + session_locks = [session.lock for session in sessions] + + # Take the tracking lock so that sessions are not started/finished while we reset the clock + try: + await tracking_lock.acquire() + [await lock.acquire() for lock in session_locks] + if now > self.start_at + dt.timedelta(minutes=5): + # Set initial clocks based on session data + # First request sessions intersection with the timeslot + memberids = [ + (sm.data.guildid, sm.data.userid) + for sg in sessions for sm in sg.members.values() + ] + session_map = {session.guildid: session for session in sessions} + model = tracker.data.VoiceSessions + if memberids: + voice_sessions = await model.table.select_where( + MEMBERS(*memberids), + model.start_time < self.end_at, + model.start_time + as_duration(model.duration) > self.start_at + ).select( + 'guildid', 'userid', 'start_time', 'channelid', + end_time=model.start_time + as_duration(model.duration) + ).with_no_adapter() + else: + voice_sessions = [] + + # Intersect and aggregate sessions, accounting for session channels + clocks = defaultdict(int) + for vsession in voice_sessions: + if session_map[vsession['guildid']].validate_channel(vsession['channelid']): + start = max(vsession['start_time'], self.start_at) + end = min(vsession['end_time'], self.end_at) + clocks[(vsession['guildid'], vsession['userid'])] += (end - start).total_seconds() + + # Now write clocks + for sg in sessions: + for sm in sg.members.values(): + sg.clock = clocks[(sm.guildid, sm.userid)] + + # Mark current attendance using current voice session + for session in sessions: + for smember in session.members.values(): + voice_session = tracker.get_session(smember.data.guildid, smember.data.userid) + smember.clock_start = None + if voice_session is not None and voice_session.activity is SessionState.ONGOING: + if session.validate_channel(voice_session.data.channelid): + smember.clock_start = max(voice_session.data.start_time, self.start_at) + session.listening = True + finally: + tracking_lock.release() + [lock.release() for lock in session_locks] + + @log_wrap(action="Prepare Sessions") + async def prepare(self, sessions: list[ScheduledSession]): + """ + Bulk prepare ScheduledSessions for the upcoming timeslot. + + Preparing means sending the initial message and adding permissions for the next members. + This does not take the session lock for setting perms, because this is race-safe + (aside from potentially leaving extra permissions, which will be overwritten by `open`). + """ + logger.debug(f"Running prepare for time slot with {len(sessions)} sessions.") + try: + coros = [session.prepare(save=False) for session in sessions if session.can_run] + await batchrun_per_second(coros, 5) + + # Save messageids + tmptable = TemporaryTable( + '_gid', '_sid', '_mid', + types=('BIGINT', 'INTEGER', 'BIGINT') + ) + tmptable.values = [ + (sg.data.guildid, sg.data.slotid, sg.messageid) + for sg in sessions + if sg.messageid is not None + ] + await Data.ScheduleSession.table.update_where( + guildid=tmptable['_gid'], slotid=tmptable['_sid'] + ).set( + messageid=tmptable['_mid'] + ).from_expr(tmptable) + except Exception: + logger.exception( + f"Unhandled exception while preparing timeslot ." + ) + else: + logger.info( + f"Prepared {len(sessions)} for scheduled session timeslot " + ) + + @log_wrap(action="Open Sessions") + async def open(self, sessions: list[ScheduledSession]): + """ + Bulk open guild sessions. + + If session opens "late", uses voice session statistics to calculate clock times. + Otherwise, uses member's current sessions. + + Due to the bulk channel update, this method may take up to 5 or 10 minutes. + """ + try: + # List of sessions which have not been previously opened + # Used so that we only set channel permissions and notify and write opened once + fresh = [session for session in sessions if session.data.opened_at is None] + + # Calculate the attended time so far, referencing voice session data if required + await self._reset_clocks(sessions) + + # Bulk update lobby messages + message_tasks = [ + asyncio.create_task(session.update_status(save=False)) + for session in sessions + if session.lobby_channel is not None + ] + notify_tasks = [ + asyncio.create_task(session.notify()) + for session in fresh + if session.lobby_channel is not None and session.data.opened_at is None + ] + + # Start lobby update loops + for session in sessions: + session.start_updating() + + # Bulk run guild open to open session rooms + voice_coros = [ + session.open_room() + for session in fresh + if session.room_channel is not None and session.data.opened_at is None + ] + await batchrun_per_second(voice_coros, 5) + await asyncio.gather(*message_tasks) + await asyncio.gather(*notify_tasks) + + # Write opened + if fresh: + now = utc_now() + tmptable = TemporaryTable( + '_gid', '_sid', '_mid', '_open', + types=('BIGINT', 'INTEGER', 'BIGINT', 'TIMESTAMPTZ') + ) + tmptable.values = [ + (sg.data.guildid, sg.data.slotid, sg.messageid, now) + for sg in fresh + ] + await Data.ScheduleSession.table.update_where( + guildid=tmptable['_gid'], slotid=tmptable['_sid'] + ).set( + messageid=tmptable['_mid'], + opened_at=tmptable['_open'] + ).from_expr(tmptable) + except Exception: + logger.exception( + f"Unhandled exception while opening sessions for timeslot ." + ) + else: + logger.info( + f"Opened {len(sessions)} sessions for scheduled session timeslot " + ) + + @log_wrap(action="Close Sessions") + async def close(self, sessions: list[ScheduledSession], consequences=False): + """ + Close the session. + + Responsible for saving the member attendance, performing economy updates, + closing the guild sessions, and if `consequences` is set, + cancels future member sessions and blacklists members as required. + Also performs the last lobby message update for this timeslot. + + Does not modify session room channels (responsibility of the next open). + """ + try: + conn = await self.bot.db.get_connection() + async with conn.transaction(): + # Calculate rewards + rewards = [] + attendance = [] + did_not_show = [] + for session in sessions: + bonus = session.bonus_reward * session.all_attended + reward = session.attended_reward + bonus + required = session.min_attendence + for member in session.members.values(): + guildid = member.guildid + userid = member.userid + attended = (member.total_clock >= required) + if attended: + rewards.append( + (TransactionType.SCHEDULE_REWARD, + guildid, self.bot.user.id, + 0, userid, + reward, 0, + None) + ) + else: + did_not_show.append((guildid, userid)) + + attendance.append( + (self.slotid, guildid, userid, attended, member.total_clock) + ) + + # Perform economy transactions + economy: Economy = self.bot.get_cog('Economy') + transactions = await economy.data.Transaction.execute_transactions(*rewards) + reward_ids = { + (t.guildid, t.to_account): t.transactionid + for t in transactions + } + + # Update lobby messages + message_tasks = [ + asyncio.create_task(session.update_status(save=False)) + for session in sessions + if session.lobby_channel is not None + ] + await asyncio.gather(*message_tasks) + + # Save attendance + if attendance: + att_table = TemporaryTable( + '_sid', '_gid', '_uid', '_att', '_clock', '_reward', + types=('INTEGER', 'BIGINT', 'BIGINT', 'BOOLEAN', 'INTEGER', 'INTEGER') + ) + att_table.values = [ + (sid, gid, uid, att, clock, reward_ids.get((gid, uid), None)) + for sid, gid, uid, att, clock in attendance + ] + await self.data.ScheduleSessionMember.table.update_where( + slotid=att_table['_sid'], + guildid=att_table['_gid'], + userid=att_table['_uid'], + ).set( + attended=att_table['_att'], + clock=att_table['_clock'], + reward_transactionid=att_table['_reward'] + ).from_expr(att_table) + + # Mark guild sessions as closed + if sessions: + await self.data.ScheduleSession.table.update_where( + slotid=self.slotid, + guildid=list(session.guildid for session in sessions) + ).set(closed_at=utc_now()) + + if consequences and did_not_show: + # Trigger blacklist and cancel member bookings as needed + await self.cog.handle_noshow(*did_not_show) + except Exception: + logger.exception( + f"Unhandled exception while closing sessions for timeslot ." + ) + else: + logger.info( + f"Closed {len(sessions)} for scheduled session timeslot " + ) + + def launch(self) -> asyncio.Task: + self.run_task = asyncio.create_task(self.run()) + return self.run_task + + @log_wrap(action="TimeSlot Run") + async def run(self): + """ + Execute each stage of the scheduled timeslot. + + Skips preparation if the open time has passed. + """ + if not self.loaded: + raise ValueError("Attempting to run a Session before loading.") + + try: + now = utc_now() + if now < self.start_at: + await discord.utils.sleep_until(self.prep_at) + self.preparing.set() + await self.prepare(list(self.sessions.values())) + else: + await discord.utils.sleep_until(self.start_at) + self.preparing.set() + self.opening.set() + await self.open(list(self.sessions.values())) + self.opened.set() + await discord.utils.sleep_until(self.end_at) + self.closing.set() + await self.close(list(self.sessions.values()), consequences=True) + except asyncio.CancelledError: + if self.closing.is_set(): + state = 'closing' + elif self.opened.is_set(): + state = 'opened' + elif self.opening.is_set(): + state = 'opening' + elif self.preparing.is_set(): + state = 'preparing' + logger.info( + f"Deactivating active time slot " + f"with state '{state}'." + ) + except Exception: + logger.exception( + f"Unexpected exception occurred while running active time slot ." + ) + + @log_wrap(action="Slot Cleanup") + async def cleanup(self, sessions: list[ScheduledSession]): + """ + Cleanup after "missed" ScheduledSessions. + + Missed sessions are unclosed sessions which are already past their closed time. + If the sessions were opened, they will be closed (with no consequences). + If the sessions were not opened, they will be cancelled (and the bookings refunded). + """ + now = utc_now() + if now < self.end_at: + raise ValueError("Attempting to cleanup sessions in current timeslot. Use close() or cancel() instead.") + + # Split provided sessions into ignore/close/cancel + to_close = [] + to_cancel = [] + for session in sessions: + if session.slotid != self.slotid: + raise ValueError(f"Timeslot {self.slotid} attempting to cleanup session with slotid {session.slotid}") + + if session.data.closed_at is not None: + # Already closed, ignore + pass + elif session.data.opened_at is not None: + # Session was opened, request close + to_close.append(session) + else: + # Session was never opened, request cancel + to_cancel.append(session) + + # Handle close + if to_close: + await self._reset_clocks(to_close) + await self.close(to_close, consequences=False) + + # Handle cancel + if to_cancel: + await self.cancel(to_cancel) + + @log_wrap(action="Cancel TimeSlot") + async def cancel(self, sessions: list[ScheduledSession]): + """ + Cancel the provided sessions. + + This involves refunding the booking transactions, deleting the booking rows, + and updating any messages that may have been posted. + """ + conn = await self.bot.db.get_connection() + async with conn.transaction(): + # Collect booking rows + bookings = [member.data for session in sessions for member in session.members.values()] + + if bookings: + # Refund booking transactions + economy: Economy = self.bot.get_cog('Economy') + maybe_tids = (r.book_transactionid for r in bookings) + tids = [tid for tid in maybe_tids if tid is not None] + await economy.data.Transaction.refund_transactions(*tids) + + # Delete booking rows + await self.data.ScheduleSessionMember.table.delete_where( + MEMBERS(*((r.guildid, r.userid) for r in bookings)), + slotid=self.slotid, + ) + + # Trigger message update for existent messages + lobby_tasks = [ + asyncio.create_task(session.update_status(save=False, resend=False)) + for session in sessions + ] + await asyncio.gather(*lobby_tasks) + + # Mark sessions as closed + await self.data.ScheduleSession.table.update_where( + slotid=self.slotid, + guildid=[session.guildid for session in sessions] + ).set( + closed_at=utc_now() + ) + # TODO: Logging diff --git a/src/modules/schedule/data.py b/src/modules/schedule/data.py new file mode 100644 index 00000000..1e007fe9 --- /dev/null +++ b/src/modules/schedule/data.py @@ -0,0 +1,156 @@ +from data import Registry, RowModel, Table +from data.columns import Integer, Timestamp, String, Bool +from utils.data import MULTIVALUE_IN + + +class ScheduleData(Registry): + class ScheduleSlot(RowModel): + """ + Schema + ------ + """ + _tablename_ = 'schedule_slots' + + slotid = Integer(primary=True) + created_at = Timestamp() + + @classmethod + async def fetch_multiple(cls, *slotids, create=True): + """ + Fetch multiple rows, applying cache where possible. + """ + results = {} + to_fetch = set() + for slotid in slotids: + row = cls._cache_.get(slotid, None) + if row is None: + to_fetch.add(slotid) + else: + results[slotid] = row + + if to_fetch: + rows = await cls.fetch_where(slotid=list(to_fetch)) + for row in rows: + results[row.slotid] = row + to_fetch.remove(row.slotid) + if to_fetch and create: + rows = await cls.table.insert_many( + ('slotid',), + *((slotid,) for slotid in to_fetch) + ).with_adapter(cls._make_rows) + for row in rows: + results[row.slotid] = row + return results + + class ScheduleSessionMember(RowModel): + """ + Schema + ------ + """ + _tablename_ = 'schedule_session_members' + + guildid = Integer(primary=True) + userid = Integer(primary=True) + slotid = Integer(primary=True) + booked_at = Timestamp() + attended = Bool() + clock = Integer() + book_transactionid = Integer() + reward_transactionid = Integer() + + class ScheduleSession(RowModel): + """ + Schema + ------ + """ + _tablename_ = 'schedule_sessions' + + guildid = Integer(primary=True) + slotid = Integer(primary=True) + opened_at = Timestamp() + closed_at = Timestamp() + messageid = Integer() + created_at = Timestamp() + + @classmethod + async def fetch_multiple(cls, *keys, create=True): + """ + Fetch multiple rows, applying cache where possible. + """ + # TODO: Factor this into a general multikey fetch many + results = {} + to_fetch = set() + for key in keys: + row = cls._cache_.get(key, None) + if row is None: + to_fetch.add(key) + else: + results[key] = row + + if to_fetch: + condition = MULTIVALUE_IN(cls._key_, *to_fetch) + rows = await cls.fetch_where(condition) + for row in rows: + results[row._rowid_] = row + to_fetch.remove(row._rowid_) + if to_fetch and create: + rows = await cls.table.insert_many( + cls._key_, + *to_fetch + ).with_adapter(cls._make_rows) + for row in rows: + results[row._rowid_] = row + return results + + class ScheduleGuild(RowModel): + """ + Schema + ------ + """ + _tablename_ = 'schedule_guild_config' + _cache_ = {} + + guildid = Integer(primary=True) + + schedule_cost = Integer() + reward = Integer() + bonus_reward = Integer() + min_attendance = Integer() + lobby_channel = Integer() + room_channel = Integer() + blacklist_after = Integer() + blacklist_role = Integer() + + @classmethod + async def fetch_multiple(cls, *guildids, create=True): + """ + Fetch multiple rows, applying cache where possible. + """ + results = {} + to_fetch = set() + for guildid in guildids: + row = cls._cache_.get(guildid, None) + if row is None: + to_fetch.add(guildid) + else: + results[guildid] = row + + if to_fetch: + rows = await cls.fetch_where(guildid=list(to_fetch)) + for row in rows: + results[row.guildid] = row + to_fetch.remove(row.guildid) + if to_fetch and create: + rows = await cls.table.insert_many( + ('guildid',), + *((guildid,) for guildid in to_fetch) + ).with_adapter(cls._make_rows) + for row in rows: + results[row.guildid] = row + return results + + """ + Schema + ------ + """ + schedule_channels = Table('schedule_channels') diff --git a/src/modules/schedule/lib.py b/src/modules/schedule/lib.py new file mode 100644 index 00000000..65e2d997 --- /dev/null +++ b/src/modules/schedule/lib.py @@ -0,0 +1,41 @@ +import asyncio +import itertools +import datetime as dt + +from utils.ratelimits import Bucket + + +def time_to_slotid(time: dt.datetime) -> int: + """ + Return the slotid for the provided time. + """ + utctime = time.astimezone(dt.timezone.utc) + hour = utctime.replace(minute=0, second=0, microsecond=0) + return int(hour.timestamp()) + + +def slotid_to_utc(sessionid: int) -> dt.datetime: + """ + Convert the given slotid (hour EPOCH) into a utc datetime. + """ + return dt.datetime.fromtimestamp(sessionid, tz=dt.timezone.utc) + + +async def batchrun_per_second(awaitables, batchsize): + """ + Run provided awaitables concurrently, + ensuring that no more than `batchsize` are running at once, + and that no more than `batchsize` are spawned per second. + + Returns list of returned results or exceptions. + """ + bucket = Bucket(batchsize, 1) + sem = asyncio.Semaphore(batchsize) + + tasks = [] + for awaitable in awaitables: + await asyncio.gather(bucket.wait(), sem.acquire()) + bucket.request() + task = asyncio.create_task(awaitable) + task.add_done_callback(lambda fut: sem.release()) + return await asyncio.gather(*tasks, return_exceptions=True) diff --git a/src/modules/schedule/settings.py b/src/modules/schedule/settings.py new file mode 100644 index 00000000..b0d8357b --- /dev/null +++ b/src/modules/schedule/settings.py @@ -0,0 +1,524 @@ +from collections import defaultdict +import discord + +from settings import ModelData, ListData +from settings.groups import SettingGroup, ModelConfig, SettingDotDict +from settings.setting_types import ( + ChannelSetting, IntegerSetting, ChannelListSetting, RoleSetting +) +from core.setting_types import CoinSetting +from meta import conf +from meta.errors import UserInputError +from meta.sharding import THIS_SHARD +from meta.logger import log_wrap + +from babel.translator import ctx_translator + +from . import babel, logger +from .data import ScheduleData + +_p = babel._p + + +class ScheduleConfig(ModelConfig): + settings = SettingDotDict() + _model_settings = set() + model = ScheduleData.ScheduleGuild + + +class ScheduleSettings(SettingGroup): + @ScheduleConfig.register_model_setting + class SessionLobby(ModelData, ChannelSetting): + setting_id = 'session_lobby' + _event = 'guildset_session_lobby' + _set_cmd = 'configure schedule' + + _display_name = _p('guildset:session_lobby', "session_lobby") + _desc = _p( + 'guildset:session_lobby|desc', + "Channel to post scheduled session announcement and status to." + ) + _long_desc = _p( + 'guildset:session_lobby|long_desc', + "Channel in which to announce scheduled sessions and post their status. " + "I must have the `MANAGE_WEBHOOKS` permission in this channel.\n" + "**This must be configured in order for the scheduled session system to function.**" + ) + _accepts = _p( + 'guildset:session_lobby|accepts', + "Name or id of the session lobby channel." + ) + + _model = ScheduleData.ScheduleGuild + _column = ScheduleData.ScheduleGuild.lobby_channel.name + + @property + def update_message(self): + t = ctx_translator.get().t + if self.data: + resp = t(_p( + 'guildset:session_lobby|set_response|set', + "Scheduled sessions will now be announced in {channel}" + )).format(channel=self.formatted) + else: + resp = t(_p( + 'guildset:session_lobby|set_response|unset', + "The schedule session lobby has been unset. Shutting down scheduled session system." + )) + return resp + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + t = ctx_translator.get().t + if data is None: + formatted = t(_p( + 'guildset:session_lobby|formatted|unset', + "`Not Set` (The scheduled session system is disabled.)" + )) + else: + formatted = t(_p( + 'guildset:session_lobby|formatted|set', + "<#{channelid}>" + )).format(channelid=data) + return formatted + + @ScheduleConfig.register_model_setting + class SessionRoom(ModelData, ChannelSetting): + setting_id = 'session_room' + _set_cmd = 'configure schedule' + + _display_name = _p('guildset:session_room', "session_room") + _desc = _p( + 'guildset:session_room|desc', + "Special voice channel open to scheduled session members." + ) + _long_desc = _p( + 'guildset:session_room|long_desc', + "If set, this voice channel serves as a dedicated room for scheduled session members. " + "During (and slightly before) each scheduled session, all members who have booked the session " + "will be given permission to join the voice channel (via permission overwrites). " + "I require the `MANAGE_CHANNEL`, `MANAGE_PERMISSIONS`, `CONNECT`, and `VIEW_CHANNEL` permissions " + "in this channel, and my highest role must be higher than all permission overwrites set in the channel." + ) + _accepts = _p( + 'guildset:session_room|accepts', + "Name or id of the session room voice channel." + ) + channel_types = [discord.VoiceChannel] + + _model = ScheduleData.ScheduleGuild + _column = ScheduleData.ScheduleGuild.room_channel.name + + @property + def update_message(self): + t = ctx_translator.get().t + if self.data: + resp = t(_p( + 'guildset:session_room|set_response|set', + "Schedule session members will now be given access to {channel}" + )).format(channel=self.formatted) + else: + resp = t(_p( + 'guildset:session_room|set_response|unset', + "The dedicated schedule session room has been removed." + )) + return resp + + class SessionChannels(ListData, ChannelListSetting): + setting_id = 'session_channels' + + _display_name = _p('guildset:session_channels', "session_channels") + _desc = _p( + 'guildset:session_channels|desc', + "Voice channels in which to track activity for scheduled sessions." + ) + _long_desc = _p( + 'guildset:session_channels|long_desc', + "Only activity in these channels (and in `session_room` if set) will count towards " + "scheduled session attendance. If a category is selected, then all channels " + "under the category will also be included. " + "Activity tracking also respects the `untracked_voice_channels` setting." + ) + _accepts = _p( + 'guildset:session_channels|accepts', + "Comma separated list of session channel names or ids." + ) + _default = None + + _table_interface = ScheduleData.schedule_channels + _id_column = 'guildid' + _data_column = 'channelid' + _order_column = 'channelid' + + _cache = {} + + @property + def update_message(self): + t = ctx_translator.get().t + if self.data: + resp = t(_p( + 'guildset:session_channels|set_response|set', + "Activity in the following sessions will now count towards scheduled session attendance: {channels}" + )).format(channels=self.formatted) + else: + resp = t(_p( + 'guildset:session_channels|set_response|unset', + "Activity in all (tracked) voice channels will now count towards session attendance." + )) + return resp + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + t = ctx_translator.get().t + if data is None: + formatted = t(_p( + 'guildset:session_channels|formatted|unset', + "All Channels (excluding `untracked_channels`)" + )) + else: + formatted = super()._format_data(parent_id, data, **kwargs) + return formatted + + @classmethod + @log_wrap(action='Cache Schedule Channels') + async def setup(cls, bot): + """ + Pre-load schedule channels for every guild on the current shard. + This includes guilds which the client cannot see. + """ + data = bot.db.registries['ScheduleData'] + + rows = await data.schedule_channels.select_where(THIS_SHARD) + new_cache = defaultdict(list) + count = 0 + for row in rows: + new_cache[row['guildid']].append(row['channelid']) + count += 1 + cls._cache.clear() + cls._cache.update(new_cache) + logger.info(f"Loaded {count} schedule session channels on this shard.") + + @ScheduleConfig.register_model_setting + class ScheduleCost(ModelData, CoinSetting): + setting_id = 'schedule_cost' + _set_cmd = 'configure schedule' + + _display_name = _p('guildset:schedule_cost', "schedule_cost") + _desc = _p( + 'guildset:schedule_cost|desc', + "Booking cost for each scheduled session." + ) + _long_desc = _p( + 'guildset:schedule_cost|long_desc', + "Members will be charged this many LionCoins for each scheduled session they book." + ) + _accepts = _p( + 'guildset:schedule_cost|accepts', + "Price of each session booking (non-negative integer)." + ) + _default = 100 + + _model = ScheduleData.ScheduleGuild + _column = ScheduleData.ScheduleGuild.schedule_cost.name + + @property + def update_message(self) -> str: + t = ctx_translator.get().t + resp = t(_p( + 'guildset:schedule_cost|set_response', + "Schedule session bookings will now cost {coin} **{amount}** per timeslot." + )).format( + coin=conf.emojis.coin, + amount=self.value + ) + return resp + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + if data is not None: + t = ctx_translator.get().t + formatted = t(_p( + 'guildset:schedule_cost|formatted', + "{coin}**{amount}** per booking." + )).format(coin=conf.emojis.coin, amount=data) + return formatted + + @ScheduleConfig.register_model_setting + class AttendanceReward(ModelData, CoinSetting): + setting_id = 'attendance_reward' + _set_cmd = 'configure schedule' + + _display_name = _p('guildset:attendance_reward', "attendance_reward") + _desc = _p( + 'guildset:attendance_reward|desc', + "Reward for attending a booked scheduled session." + ) + _long_desc = _p( + 'guildset:attendance_reward|long_desc', + "When a member successfully attends a scheduled session they booked, " + "they will be awarded this many LionCoins. " + "Should generally be more than the `schedule_cost` setting." + ) + _accepts = _p( + 'guildset:attendance_reward|accepts', + "Number of coins to reward session attendance." + ) + _default = 200 + + _model = ScheduleData.ScheduleGuild + _column = ScheduleData.ScheduleGuild.reward.name + + @property + def update_message(self) -> str: + t = ctx_translator.get().t + resp = t(_p( + 'guildset:attendance_reward|set_response', + "Members will be rewarded {coin}**{amount}** when they attend a scheduled session." + )).format(coin=conf.emojis.coin, amount=self.value) + return resp + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + if data is not None: + t = ctx_translator.get().t + formatted = t(_p( + 'guildset:attendance_reward|formatted', + "{coin}**{amount}** upon attendance." + )).format(coin=conf.emojis.coin, amount=data) + return formatted + + @ScheduleConfig.register_model_setting + class AttendanceBonus(ModelData, CoinSetting): + setting_id = 'attendance_bonus' + _set_cmd = 'configure schedule' + + _display_name = _p('guildset:attendance_bonus', "group_attendance_bonus") + _desc = _p( + 'guildset:attendance_bonus|desc', + "Bonus reward given when all members attend a scheduled session." + ) + _long_desc = _p( + 'guildset:attendance_bonus|long_desc', + "When all members who have booked a session successfully attend the session, " + "they will be given this bonus in *addition* to the `attendance_reward`." + ) + _accepts = _p( + 'guildset:attendance_bonus|accepts', + "Bonus coins rewarded when everyone attends a session." + ) + _default = 200 + + _model = ScheduleData.ScheduleGuild + _column = ScheduleData.ScheduleGuild.bonus_reward.name + + @property + def update_message(self) -> str: + t = ctx_translator.get().t + resp = t(_p( + 'guildset:attendance_bonus|set_response', + "Session members will be rewarded an additional {coin}**{amount}** when everyone attends." + )).format(coin=conf.emojis.coin, amount=self.value) + return resp + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + if data is not None: + t = ctx_translator.get().t + formatted = t(_p( + 'guildset:attendance_bonus|formatted', + "{coin}**{amount}** bonus when all booked members attend." + )).format(coin=conf.emojis.coin, amount=data) + return formatted + + @ScheduleConfig.register_model_setting + class MinAttendance(ModelData, IntegerSetting): + setting_id = 'min_attendance' + _set_cmd = 'configure schedule' + + _display_name = _p('guildset:min_attendance', "min_attendance") + _desc = _p( + 'guildset:min_attendance|desc', + "Minimum attendance before reward eligability." + ) + _long_desc = _p( + 'guildset:min_attendance|long_desc', + "Scheduled session members will need to attend the session for at least this number of minutes " + "before they are marked as having attended (and hence are rewarded)." + ) + _accepts = _p( + 'guildset:min_attendance|accepts', + "Number of minutes (1-60) before attendance is counted." + ) + _default = 10 + _min = 1 + _max = 60 + + _model = ScheduleData.ScheduleGuild + _column = ScheduleData.ScheduleGuild.min_attendance.name + + @property + def update_message(self) -> str: + t = ctx_translator.get().t + resp = t(_p( + 'guildset:min_attendance|set_response', + "Members will be rewarded after they have attended booked sessions for at least **`{amount}`** minutes." + )).format(amount=self.value) + return resp + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + if data is not None: + t = ctx_translator.get().t + formatted = t(_p( + 'guildset:min_attendance|formatted', + "**`{amount}`** minutes" + )).format(amount=data) + return formatted + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + if not string: + return None + + string = string.strip('m ') + + num = int(string) if string.isdigit() else None + try: + num = int(string) + except Exception: + num = None + + if num is None or not 0 < num < 60: + t = ctx_translator.get().t + error = t(_p( + 'guildset:min_attendance|parse|error', + "Minimum attendance must be an integer number of minutes between `1` and `60`." + )) + raise UserInputError(error) + + @ScheduleConfig.register_model_setting + class BlacklistRole(ModelData, RoleSetting): + setting_id = 'schedule_blacklist_role' + _set_cmd = 'configure schedule' + _event = 'guildset_schedule_blacklist_role' + + _display_name = _p('guildset:schedule_blacklist_role', "schedule_blacklist_role") + _desc = _p( + 'guildset:schedule_blacklist_role|desc', + "Role which disables scheduled session booking." + ) + _long_desc = _p( + 'guildset:schedule_blacklist_role|long_desc', + "Members with this role will not be allowed to book scheduled sessions in this server. " + "If the role is manually added, all future scheduled sessions for the user are cancelled. " + "This provides a way to stop repeatedly unreliable members from blocking the group bonus for all members. " + "Alternatively, consider setting the booking cost (and reward) very high to provide " + "a strong disincentive for not attending a session." + ) + _accepts = _p( + 'guildset:schedule_blacklist_role|accepts', + "Blacklist role name or id." + ) + + _model = ScheduleData.ScheduleGuild + _column = ScheduleData.ScheduleGuild.blacklist_role.name + + @property + def update_message(self): + t = ctx_translator.get().t + if self.data: + resp = t(_p( + 'guildset:schedule_blacklist_role|set_response|set', + "Members with {role} will be unable to book scheduled sessions." + )).format(role=self.formatted) + else: + resp = t(_p( + 'guildset:schedule_blacklist_role|set_response|unset', + "The schedule blacklist role has been unset." + )) + return resp + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + t = ctx_translator.get().t + if data is not None: + formatted = t(_p( + 'guildset:schedule_blacklist_role|formatted|set', + "{role} members will not be able to book scheduled sessions." + )).format(role=f"<&{data}>") + else: + formatted = t(_p( + 'guildset:schedule_blacklist_role|formatted|unset', + "Not Set" + )) + return formatted + + @ScheduleConfig.register_model_setting + class BlacklistAfter(ModelData, IntegerSetting): + setting_id = 'schedule_blacklist_after' + _set_cmd = 'configure schedule' + + _display_name = _p('guildset:schedule_blacklist_after', "schedule_blacklist_after") + _desc = _p( + 'guildset:schedule_blacklist_after|desc', + "Number of missed sessions within 24h before blacklisting." + ) + _long_desc = _p( + 'guildset:schedule_blacklist_after|long_desc', + "Members who miss more than this number of booked sessions in a single 24 hour period " + "will be automatically given the `blacklist_role`. " + "Has no effect if the `blacklist_role` is not set or if I do not have sufficient permissions " + "to assign the blacklist role." + ) + _accepts = _p( + 'guildset:schedule_blacklist_after|accepts', + "A number of missed sessions (1-24) before blacklisting." + ) + _default = None + _min = 1 + _max = 24 + + _model = ScheduleData.ScheduleGuild + _column = ScheduleData.ScheduleGuild.blacklist_after.name + + @property + def update_message(self) -> str: + t = ctx_translator.get().t + if self.data: + resp = t(_p( + 'guildset:schedule_blacklist_after|set_response|set', + "Members will be blacklisted after **`{amount}`** missed sessions within `24h`." + )).format(amount=self.data) + else: + resp = t(_p( + 'guildset:schedule_blacklist_after|set_response|unset', + "Members will not be automatically blacklisted from booking scheduled sessions." + )) + return resp + + @classmethod + def _format_data(cls, parent_id, data, **kwargs): + t = ctx_translator.get().t + if data is not None: + formatted = t(_p( + 'guildset:schedule_blacklist_after|formatted|set', + "Blacklist after **`{amount}`** missed sessions within `24h`." + )).format(amount=data) + else: + formatted = t(_p( + 'guildset:schedule_blacklist_after|formatted|unset', + "Do not automatically blacklist." + )) + return formatted + + @classmethod + async def _parse_string(cls, parent_id, string: str, **kwargs): + try: + return await super()._parse_string(parent_id, string, **kwargs) + except UserInputError: + t = ctx_translator.get().t + error = t(_p( + 'guildset:schedule_blacklist_role|parse|error', + "Blacklist threshold must be a number between `1` and `24`." + )) + raise UserInputError(error) from None diff --git a/src/modules/schedule/ui/scheduleui.py b/src/modules/schedule/ui/scheduleui.py new file mode 100644 index 00000000..9e409e8d --- /dev/null +++ b/src/modules/schedule/ui/scheduleui.py @@ -0,0 +1,660 @@ +from typing import Optional, TYPE_CHECKING +import asyncio +import math + +import discord +from discord.ui.select import select, Select, SelectOption +from discord.ui.button import button, Button, ButtonStyle + +from meta import conf, LionBot +from meta.errors import UserInputError +from data import ORDER + +from utils.ui import MessageUI, Confirm +from utils.lib import MessageArgs, utc_now, tabulate, error_embed +from babel.translator import ctx_translator + +from .. import babel, logger +from ..data import ScheduleData +from ..lib import slotid_to_utc, time_to_slotid +from ..settings import ScheduleConfig, ScheduleSettings + +_p, _np = babel._p, babel._np + +if TYPE_CHECKING: + from ..cog import ScheduleCog + from core.lion_member import LionMember + + +guide = _p( + 'ui:schedule|about', + "Guide tips here TBD" +) + + +class ScheduleUI(MessageUI): + """ + Primary UI pathway for viewing and modifying a member's schedule. + """ + + def __init__(self, bot: LionBot, guild: discord.Guild, callerid: int, **kwargs): + super().__init__(callerid=callerid, **kwargs) + self.bot = bot + self.cog: ScheduleCog = bot.get_cog('ScheduleCog') + self.guild = guild + + self.guildid = guild.id + self.userid = callerid + self.lion: LionMember = None + + # Data state + self.config: ScheduleConfig = None + self.blacklisted = False + self.schedule = {} # ordered map slotid -> ScheduleSessionMember + self.guilds = {} # Cache of guildid -> ScheduleGuild + + # Statistics + self.recent_stats = (0, 0) + self.recent_avg = 0 + self.all_stats = (0, 0) + self.all_avg = 0 + + self.streak = 0 + + # UI state + self.show_info = False + self.initial_load = False + self.now = utc_now() + self.nowid = time_to_slotid(self.now) + + # ----- API ----- + + # ----- UI Components ----- + # IDEA: History button? + + @button(emoji=conf.emojis.cancel, style=ButtonStyle.red) + async def quit_button(self, press: discord.Interaction, pressed: Button): + """ + Quit the schedule + """ + await press.response.defer() + await self.quit() + + @button(emoji=conf.emojis.refresh, style=ButtonStyle.grey) + async def refresh_button(self, press: discord.Interaction, pressed: Button): + """ + Refresh the schedule + """ + await press.response.defer(thinking=True, ephemeral=True) + self.show_info = False + self.initial_load = False + await self.refresh(thinking=press) + + @button(label='CLEAR_PLACEHOLDER', style=ButtonStyle.red) + async def clear_button(self, press: discord.Interaction, pressed: Button): + """ + Clear future sessions for this user. + """ + await press.response.defer(thinking=True, ephemeral=True) + t = self.bot.translator.t + + # First update the schedule + now = self.now = utc_now() + nowid = self.nowid = time_to_slotid(now) + nextid = nowid + 3600 + await self._load_schedule() + slotids = set(self.schedule.keys()) + + # Remove uncancellable slots + slotids.discard(nowid) + if (slotid_to_utc(nextid) - now).total_seconds() < 60: + slotids.discard(nextid) + if not slotids: + # Nothing to cancel + error = t(_p( + 'ui:schedule|button:clear|error:nothing', + "No upcoming sessions to cancel! Your schedule is already clear." + )) + embed = error_embed(error) + else: + # Do cancel + await self.cog.cancel_bookings( + *( + (slotid, self.schedule[slotid].guildid, self.userid) + for slotid in slotids + ) + ) + ack = t(_p( + 'ui:schedule|button:clear|success', + "Successfully cancelled and refunded your upcoming scheduled sessions." + )) + embed = discord.Embed( + colour=discord.Colour.brand_green(), + description=ack + ) + await press.edit_original_response(embed=embed) + self.show_info = False + await self.refresh() + + async def clear_button_refresh(self): + self.clear_button.label = self.bot.translator.t(_p( + 'ui:schedule|button:clear|label', + "Clear Schedule" + )) + if not self.schedule: + self.clear_button.disabled = True + + @button(label='ABOUT_PLACEHOLDER', emoji=conf.emojis.question, style=ButtonStyle.grey) + async def about_button(self, press: discord.Interaction, pressed: Button): + """ + Replace message with the info page (temporarily). + """ + await press.response.defer(thinking=True, ephemeral=True) + self.show_info = not self.show_info + await self.refresh(thinking=press) + + async def about_button_refresh(self): + self.about_button.label = self.bot.translator.t(_p( + 'ui:schedule|button:about|label', + "About Schedule" + )) + self.about_button.style = ButtonStyle.grey if self.show_info else ButtonStyle.blurple + + @select(cls=Select, placeholder='BOOK_MENU_PLACEHOLDER') + async def booking_menu(self, selection: discord.Interaction, selected): + if selected.values[0] == 'None': + await selection.response.defer() + return + + await selection.response.defer(thinking=True, ephemeral=True) + t = self.bot.translator.t + + # Refresh the schedule + now = self.now = utc_now() + nowid = self.nowid = time_to_slotid(now) + nextid = nowid + 3600 + next_soon = ((slotid_to_utc(nextid) - now).total_seconds() < 60) + await self._load_schedule() + + # Check the requested slots + slotids = set(map(int, selected.values)) + if nowid in slotids: + # Error with cannot book now + error = t(_p( + 'ui:schedule|menu:booking|error:current_slot', + "You cannot schedule a currently running session!" + )) + embed = error_embed(error) + elif (nextid in slotids) and next_soon: + # Error with too late + error = t(_p( + 'ui:schedule|menu:booking|error:next_slot', + "Too late! You cannot schedule a session starting in the next minute." + )) + embed = error_embed(error) + elif slotids.intersection(self.schedule.keys()): + # Error with already booked + error = t(_p( + 'ui:schedule|menu:booking|error:already_booked', + "You have already booked one or more of the requested sessions!" + )) + embed = error_embed(error) + else: + # Okay, slotids are valid. + # Check member balance is sufficient + await self.lion.data.refresh() + balance = self.lion.data.coins + requested = len(slotids) + required = requested * self.config.get(ScheduleSettings.ScheduleCost.setting_id).value + if required > balance: + error = t(_p( + 'ui:schedule|menu:booking|error:insufficient_balance', + "Booking `{count}` scheduled sessions requires {coin}**{required}**, " + "but you only have {coin}**{balance}**!" + )).format( + count=requested, coin=conf.emojis.coin, required=required, balance=balance + ) + embed = error_embed(error) + else: + # Everything checks out, run the booking + try: + await self.cog.create_booking(self.guildid, self.userid, *slotids) + timestrings = [ + discord.utils.format_dt(slotid_to_utc(slotid), style='T') + for slotid in slotids + ] + ack = t(_np( + 'ui:schedule|menu:booking|success', + "Successfully booked your scheduled session at {times}.", + "Successfully booked the following scheduled sessions.\n{times}", + len(slotids) + )).format( + times='\n'.join(timestrings) + ) + embed = discord.Embed( + colour=discord.Colour.brand_green(), + description=ack + ) + except UserInputError as e: + embed = error_embed(e.msg) + await selection.edit_original_response(embed=embed) + self.show_info = False + await self.refresh() + + async def booking_menu_refresh(self): + t = self.bot.translator.t + menu = self.booking_menu + + if self.blacklisted: + placeholder = t(_p( + 'ui:schedule|menu:booking|placeholder:blacklisted', + "Book Sessions (Cannot book - Blacklisted)" + )) + disabled = True + options = [] + else: + disabled = False + placeholder = t(_p( + 'ui:schedule|menu:booking|placeholder:regular', + "Book Sessions ({amount} LC)" + )).format( + coin=conf.emojis.coin, + amount=self.config.get(ScheduleSettings.ScheduleCost.setting_id).value + ) + + # Populate with choices + nowid = self.nowid + upcoming = [nowid + 3600 * i for i in range(1, 25)] + upcoming = [slotid for slotid in upcoming if slotid not in self.schedule] + options = self._format_slot_options(*upcoming) + + menu.placeholder = placeholder + if options: + menu.options = options + menu.disabled = disabled + menu.max_values = len(menu.options) + else: + menu.options = [ + SelectOption(label='None', value='None') + ] + menu.disabled = True + menu.max_values = 1 + + def _format_slot_options(self, *slotids: int) -> list[SelectOption]: + """ + Format provided slotids into Select Options. + + ``` + Today 23:00 (in <1 hour) + Tommorrow 01:00 (in 3 hours) + Today/Tomorrow {start} (in 1 hour) + ``` + """ + t = self.bot.translator.t + options = [] + tz = self.lion.timezone + nowid = self.nowid + now = self.now.astimezone(tz) + + slot_format = t(_p( + 'ui:schedule|menu:slots|option|format', + "{day} {time} (in {until})" + )) + today_name = t(_p( + 'ui:schedule|menu:slots|option|day:today', + "Today" + )) + tomorrow_name = t(_p( + 'ui:schedule|menu:slots|option|day:tomorrow', + "Tomorrow" + )) + + for slotid in slotids: + slot_start = slotid_to_utc(slotid).astimezone(tz) + distance = int((slotid - nowid) // 3600) + until = self._format_until(distance) + day = today_name if (slot_start.day == now.day) else tomorrow_name + name = slot_format.format( + day=day, + time=slot_start.strftime('%H:%M'), + until=until + ) + + options.append(SelectOption(label=name, value=str(slotid))) + return options + + def _format_until(self, distance): + t = self.bot.translator.t + return t(_np( + 'ui:schedule|format_until', + "<1 hour", + "{number} hours", + distance + )).format(number=distance) + + @select(cls=Select, placeholder='CANCEL_MENU_PLACEHOLDER') + async def cancel_menu(self, selection: discord.Interaction, selected): + """ + Cancel the selected slotids. + + Refuses to cancel a slot if it is already running or within one minute of running. + """ + await selection.response.defer(thinking=True, ephemeral=True) + t = self.bot.translator.t + + # Collect slotids that were requested + slotids = list(map(int, selected.values)) + + # Check for 'forbidden' slotids (possible due to long running UI) + now = utc_now() + nowid = time_to_slotid(now) + if nowid in slotids: + error = t(_p( + 'ui:schedule|menu:cancel|error:current_slot', + "You cannot cancel a currently running *scheduled* session! Please attend it if possible." + )) + embed = error_embed(error) + elif (nextid := nowid + 3600) in slotids and (slotid_to_utc(nextid) - now).total_seconds() < 60: + error = t(_p( + 'ui:schedule|menu:cancel|error:too_late', + "Too late! You cannot cancel a scheduled session within a minute of it starting. " + "Please attend it if possible." + )) + embed = error_embed(error) + else: + # Remaining slotids are now cancellable + # Although there is no guarantee the bookings are still valid. + # Request booking cancellation + booking_records = await self.cog.cancel_bookings( + *( + (slotid, self.schedule[slotid].guildid, self.userid) + for slotid in slotids + ) + ) + if not booking_records: + error = t(_p( + 'ui:schedule|menu:cancel|error:already_cancelled', + "The selected bookings no longer exist! Nothing to cancel." + )) + embed = error_embed(error) + else: + timestrings = [ + discord.utils.format_dt(slotid_to_utc(record['slotid']), style='T') + for record in booking_records + ] + ack = t(_np( + 'ui:schedule|menu:cancel|success', + "Successfully cancelled and refunded your scheduled session booking for {times}.", + "Successfully cancelled and refunded your scheduled session bookings:\n{times}.", + len(booking_records) + )).format( + times='\n'.join(timestrings) + ) + embed = discord.Embed( + colour=discord.Colour.brand_green(), + description=ack + ) + + await selection.edit_original_response(embed=embed) + self.show_info = False + await self.refresh() + + async def cancel_menu_refresh(self): + t = self.bot.translator.t + menu = self.cancel_menu + + menu.placeholder = t(_p( + 'ui:schedule|menu:cancel|placeholder', + "Cancel booked sessions" + )) + can_cancel = set(self.schedule.keys()) + can_cancel.discard(self.nowid) + menu.options = self._format_slot_options(*can_cancel) + menu.max_values = len(menu.options) + + # ----- UI Flow ----- + async def make_message(self) -> MessageArgs: + t = self.bot.translator.t + # Show booking cost somewhere (in booking menu) + # Show info automatically if member has never booked a session + embed = discord.Embed( + colour=discord.Colour.orange(), + ) + member = self.lion.member + embed.set_author( + name=t(_p( + 'ui:schedule|embed|author', + "Your Scheduled Sessions and Past Statistics" + )).format(name=member.display_name if member else self.lion.luser.data.name), + icon_url=self.lion.member.avatar + ) + if self.show_info: + # Info message + embed.description = t(guide) + else: + # Statistics table + stats_fields = {} + recent_key = t(_p( + 'ui:schedule|embed|field:stats|field:recent', + "Recent" + )) + recent_value = self._format_stats(*self.recent_stats, self.recent_avg) + stats_fields[recent_key] = recent_value + if self.recent_stats[1] == 100: + alltime_key = t(_p( + 'ui:schedule|embed|field:stats|field:alltime', + "All Time" + )) + alltime_value = self._format_stats(*self.all_stats, self.all_avg) + stats_fields[alltime_key] = alltime_value + streak_key = t(_p( + 'ui:schedule|embed|field:stats|field:streak', + "Streak" + )) + if self.streak: + streak_value = t(_np( + 'ui:schedule|embed|field:stats|field:streak|value:zero', + "One session attended! Keep it up!", + "**{streak}** sessions attended in a row! Good job!", + self.streak, + )).format(streak=self.streak) + else: + streak_value = t(_p( + 'ui:schedule|embed|field:stats|field:streak|value:positive', + "No streak yet!" + )) + stats_fields[streak_key] = streak_value + + table = tabulate(*stats_fields.items()) + embed.add_field( + name=t(_p( + 'ui:schedule|embed|field:stats|name', + "Session Statistics" + )), + value='\n'.join(table), + inline=False + ) + + # Upcoming sessions + upcoming = list(self.schedule.values()) + guildids = set(row.guildid for row in upcoming) + show_guild = (len(guildids) > 1) or (self.guildid not in guildids) + + # Split lists in about half if they are too long for one field. + split = math.ceil(len(upcoming) / 2) if len(upcoming) >= 12 else 12 + block1 = upcoming[:split] + block2 = upcoming[split:] + + embed.add_field( + name=t(_p( + 'ui:schedule|embed|field:upcoming|name', + "Upcoming Sessions" + )), + value=self._format_bookings(block1, show_guild) if block1 else t(_p( + 'ui:schedule|embed|field:upcoming|value:empty', + "No sessions scheduled yet!" + )) + ) + if block2: + embed.add_field( + name='-'*5, + value=self._format_bookings(block2, show_guild) + ) + return MessageArgs(embed=embed) + + def _format_stats(self, attended, total, average): + t = self.bot.translator.t + return t(_p( + 'ui:schedule|embed|stats_format', + "**{attended}** attended out of **{total}** booked.\r\n" + "**{percent}%** attendance rate.\r\n" + "**{average}** average attendance time." + )).format( + attended=attended, + total=total, + percent=math.ceil(attended/total * 100) if total else 0, + average=f"{int(average // 60)}:{average % 60:02}" + ) + + def _format_bookings(self, bookings, show_guild=False): + t = self.bot.translator.t + short_format = t(_p( + 'ui:schedule|booking_format:short', + "`in {until}` | {start} - {end}" + )) + long_format = t(_p( + 'ui:schedule|booking_format:long', + "> `in {until}` | {start} - {end}" + )) + items = [] + format = long_format if show_guild else short_format + last_guildid = None + for booking in bookings: + guildid = booking.guildid + data = self.guilds[guildid] + + if last_guildid != guildid: + channel = f"<#{data.lobby_channel}>" + items.append(channel) + last_guildid = guildid + + start = slotid_to_utc(booking.slotid) + end = slotid_to_utc(booking.slotid + 3600) + item = format.format( + until=self._format_until(int((booking.slotid - self.nowid) // 3600)), + start=discord.utils.format_dt(start, style='t'), + end=discord.utils.format_dt(end, style='t'), + ) + items.append(item) + return '\n'.join(items) + + async def refresh_layout(self): + # Don't show cancel menu or clear button if the schedule is empty + await asyncio.gather( + self.clear_button_refresh(), + self.about_button_refresh(), + self.booking_menu_refresh(), + self.cancel_menu_refresh(), + ) + if self.schedule and self.cancel_menu.options: + self.set_layout( + (self.about_button, self.refresh_button, self.clear_button, self.quit_button), + (self.booking_menu,), + (self.cancel_menu,), + ) + else: + self.set_layout( + (self.about_button, self.refresh_button, self.quit_button), + (self.booking_menu,) + ) + + async def reload(self): + now = utc_now() + nowid = time_to_slotid(now) + self.initial_load = self.initial_load and (nowid == self.nowid) + self.now = now + self.nowid = nowid + + if not self.initial_load: + await self._load_member() + await self._load_statistics() + self.show_info = not self.recent_stats[1] + self.initial_load = True + + await self._load_schedule() + + member = self.guild.get_member(self.userid) + blacklist_role = self.config.get(ScheduleSettings.BlacklistRole.setting_id).value + self.blacklisted = member and blacklist_role and (blacklist_role in member.roles) + + async def _load_schedule(self): + """ + Load current member schedule and update guild config cache. + """ + nowid = self.nowid + + booking_model = self.cog.data.ScheduleSessionMember + bookings = await booking_model.fetch_where( + booking_model.slotid >= nowid, + userid=self.userid, + ).order_by('slotid', ORDER.ASC) + guildids = list(set(booking.guildid for booking in bookings)) + guilds = await self.cog.data.ScheduleGuild.fetch_multiple(*guildids) + self.guilds.update(guilds) + self.schedule = { + booking.slotid: booking for booking in bookings + } + + async def _load_member(self): + self.lion = await self.bot.core.lions.fetch_member(self.guildid, self.userid) + await self.lion.data.refresh() + + guild_data = await self.cog.data.ScheduleGuild.fetch_or_create(self.guildid) + self.guilds[self.guildid] = guild_data + self.config = ScheduleConfig(self.guildid, guild_data) + + async def _load_statistics(self): + now = utc_now() + nowid = time_to_slotid(now) + + # Fetch (up to 100) most recent bookings + booking_model = self.cog.data.ScheduleSessionMember + recent = await booking_model.fetch_where( + booking_model.slotid < nowid, + userid=self.userid, + ).order_by('slotid', ORDER.DESC).limit(100) + + # Calculate recent stats + recent_total_clock = 0 + recent_att = 0 + recent_count = len(recent) + streak = 0 + streak_broken = False + for row in recent: + recent_total_clock += row.clock + if row.attended: + recent_att += 1 + if not streak_broken: + streak += 1 + else: + streak_broken = True + + self.recent_stats = (recent_att, recent_count) + self.recent_avg = int(recent_total_clock // (60 * recent_count)) if recent_count else 0 + self.streak = streak + + # Calculate all-time stats + if recent_count == 100: + record = await booking_model.table.select_one_where( + booking_model.slotid < nowid, + userid=self.userid, + ).select( + _booked='COUNT(*)', + _attended='COUNT(*) FILTER (WHERE attended)', + _clocked='SUM(COALESCE(clock, 0))' + ).with_no_adapter() + self.all_stats = (record['_attended'], record['_booked']) + self.all_avg = record['_clocked'] // (60 * record['_booked']) + else: + self.all_stats = self.recent_stats + self.all_avg = self.recent_avg diff --git a/src/modules/schedule/ui/sessionui.py b/src/modules/schedule/ui/sessionui.py new file mode 100644 index 00000000..4d2a737c --- /dev/null +++ b/src/modules/schedule/ui/sessionui.py @@ -0,0 +1,180 @@ +from typing import Optional, TYPE_CHECKING +import asyncio + +import discord +from discord.ui.button import button, Button, ButtonStyle + +from meta import conf, LionBot +from meta.errors import UserInputError +from utils.lib import utc_now +from utils.ui import LeoUI +from babel.translator import ctx_locale + +from .. import babel, logger +from ..lib import slotid_to_utc, time_to_slotid + +from .scheduleui import ScheduleUI + +if TYPE_CHECKING: + from ..cog import ScheduleCog + +_p = babel._p + + +class SessionUI(LeoUI): + # Maybe add a button to check channel permissions + # And make the session update the channel if it is missing permissions + + def __init__(self, bot: LionBot, slotid: int, guildid: int, **kwargs): + kwargs.setdefault('timeout', 3600) + super().__init__(**kwargs) + self.bot = bot + self.cog: 'ScheduleCog' = bot.get_cog('ScheduleCog') + self.slotid = slotid + self.slot_start = slotid_to_utc(slotid) + self.guildid = guildid + self.locale = None + + @property + def starting_soon(self): + return (self.slot_start - utc_now()).total_seconds() < 60 + + async def init_components(self): + """ + Localise components. + """ + lguild = await self.bot.core.lions.fetch_guild(self.guildid) + locale = self.locale = lguild.locale + t = self.bot.translator.t + + self.book_button.label = t(_p( + 'ui:sessionui|button:book|label', + "Book" + ), locale) + self.cancel_button.label = t(_p( + 'ui:sessionui|button:cancel|label', + "Cancel" + ), locale) + self.schedule_button.label = t(_p( + 'ui:sessionui|button:schedule|label', + 'Open Schedule' + ), locale) + + # ----- API ----- + async def reload(self): + await self.init_components() + if self.starting_soon: + # Slot is about to start or slot has already started + self.set_layout((self.schedule_button,)) + else: + self.set_layout( + (self.book_button, self.cancel_button, self.schedule_button), + ) + + # ----- UI Components ----- + @button(label='BOOK_PLACEHOLDER', style=ButtonStyle.blurple) + async def book_button(self, press: discord.Interaction, pressed: Button): + await press.response.defer(thinking=True, ephemeral=True) + t = self.bot.translator.t + babel = self.bot.get_cog('BabelCog') + locale = await babel.get_user_locale(press.user.id) + ctx_locale.set(locale) + + error = None + if self.starting_soon: + error = t(_p( + 'ui:session|button:book|error:starting_soon', + "Too late! This session has started or is starting shortly." + )) + else: + schedule = await self.cog._fetch_schedule(press.user.id) + if self.slotid in schedule: + error = t(_p( + 'ui:session|button:book|error:already_booked', + "You are already a member of this session!" + )) + else: + try: + await self.cog.create_booking(self.guildid, press.user.id, self.slotid) + ack = t(_p( + 'ui:session|button:book|success', + "Successfully booked this session." + )) + embed = discord.Embed( + colour=discord.Colour.brand_green(), + description=ack + ) + except UserInputError as e: + error = e.msg + if error is not None: + embed = discord.Embed( + colour=discord.Colour.brand_red(), + description=error, + title=t(_p( + 'ui:session|button:book|error|title', + "Could not book session" + )) + ) + + await press.edit_original_response(embed=embed) + + @button(label='CANCEL_PLACHEHOLDER', style=ButtonStyle.blurple) + async def cancel_button(self, press: discord.Interaction, pressed: Button): + await press.response.defer(thinking=True, ephemeral=True) + t = self.bot.translator.t + babel = self.bot.get_cog('BabelCog') + locale = await babel.get_user_locale(press.user.id) + ctx_locale.set(locale) + + error = None + if self.starting_soon: + error = t(_p( + 'ui:session|button:cancel|error:starting_soon', + "Too late! This session has started or is starting shortly." + )) + else: + schedule = await self.cog._fetch_schedule(press.user.id) + if self.slotid not in schedule: + error = t(_p( + 'ui:session|button:cancel|error:not_booked', + "You are not a member of this session!" + )) + else: + try: + await self.cog.cancel_bookings( + (self.slotid, self.guildid, press.user.id), + refund=True + ) + ack = t(_p( + 'ui:session|button:cancel|success', + "Successfully cancelled this session." + )) + embed = discord.Embed( + colour=discord.Colour.brand_green(), + description=ack + ) + except UserInputError as e: + error = e.msg + if error is not None: + embed = discord.Embed( + colour=discord.Colour.brand_red(), + description=error, + title=t(_p( + 'ui:session|button:cancel|error|title', + "Could not cancel session" + )) + ) + + await press.edit_original_response(embed=embed) + + @button(label='SCHEDULE_PLACEHOLDER', style=ButtonStyle.blurple) + async def schedule_button(self, press: discord.Interaction, pressed: Button): + await press.response.defer(thinking=True, ephemeral=True) + + babel = self.bot.get_cog('BabelCog') + locale = await babel.get_user_locale(press.user.id) + ctx_locale.set(locale) + + ui = ScheduleUI(self.bot, press.guild, press.user.id) + await ui.run(press) + await ui.wait() diff --git a/src/modules/schedule/ui/settingui.py b/src/modules/schedule/ui/settingui.py new file mode 100644 index 00000000..f0799799 --- /dev/null +++ b/src/modules/schedule/ui/settingui.py @@ -0,0 +1,233 @@ +import itertools +import asyncio + +import discord +from discord.ui.button import button, Button, ButtonStyle +from discord.ui.select import select, ChannelSelect, RoleSelect + +from meta import LionBot + +from utils.ui import ConfigUI, DashboardSection +from utils.lib import MessageArgs + +from ..settings import ScheduleSettings +from .. import babel + +_p = babel._p + + +class ScheduleSettingUI(ConfigUI): + pages = [ + ( + ScheduleSettings.SessionLobby, + ScheduleSettings.SessionRoom, + ScheduleSettings.SessionChannels, + ScheduleSettings.ScheduleCost, + ), ( + ScheduleSettings.AttendanceReward, + ScheduleSettings.AttendanceBonus, + ScheduleSettings.MinAttendance, + ), ( + ScheduleSettings.BlacklistRole, + ScheduleSettings.BlacklistAfter, + ) + ] + setting_classes = list(itertools.chain(*pages)) + + def _init_children(self): + # HACK to stop ViewWeights complaining that this UI has too many children + # Children will be correctly initialised after parent init. + return [] + + def __init__(self, bot: LionBot, guildid: int, channelid: int, **kwargs): + self.settings = bot.get_cog('ScheduleCog').settings + super().__init__(bot, guildid, channelid, **kwargs) + self._children = super()._init_children() + self.page_num = 0 + + def get_instance(self, setting): + return next(instance for instance in self.instances if instance.setting_id == setting.setting_id) + + @property + def page_instances(self): + start = sum(len(page) for page in self.pages[:self.page_num]) + end = start + len(self.pages[self.page_num]) + return self.instances[start:end] + + # ----- UI Components ----- + # Page 0 button + @button(label="PAGE0_BUTTON_PLACEHOLDER", style=ButtonStyle.grey) + async def page0_button(self, press: discord.Interaction, pressed: Button): + await press.response.defer(thinking=True, ephemeral=True) + self.page_num = 0 + await self.refresh(thinking=press) + + async def page0_button_refresh(self): + t = self.bot.translator.t + self.page0_button.label = t(_p( + 'ui:schedule_config|button:page0|label', + "Page 1" + )) + self.page0_button.disabled = (self.page_num == 0) + + # Lobby channel selector + @select(cls=ChannelSelect, channel_types=[discord.ChannelType.text, discord.ChannelType.voice], + min_values=0, max_values=1, + placeholder='LOBBY_PLACEHOLDER') + async def lobby_menu(self, selection: discord.Interaction, selected: ChannelSelect): + # TODO: Setting value checks + await selection.response.defer() + setting = self.get_instance(ScheduleSettings.SessionLobby) + setting.value = selected.values[0] if selected.values else None + await setting.write() + + async def lobby_menu_refresh(self): + t = self.bot.translator.t + self.lobby_menu.placeholder = t(_p( + 'ui:schedule_config|menu:lobby|placeholder', + "Select Lobby Channel" + )) + + # Room channel selector + @select(cls=ChannelSelect, channel_types=[discord.ChannelType.voice], + min_values=0, max_values=1, + placeholder='ROOM_PLACEHOLDER') + async def room_menu(self, selection: discord.Interaction, selected: ChannelSelect): + await selection.response.defer() + setting = self.get_instance(ScheduleSettings.SessionRoom) + setting.value = selected.values[0] if selected.values else None + await setting.write() + + async def room_menu_refresh(self): + t = self.bot.translator.t + self.room_menu.placeholder = t(_p( + 'ui:schedule_config|menu:room|placeholder', + "Select Session Room" + )) + + # Session channels selector + @select(cls=ChannelSelect, channel_types=[discord.ChannelType.category, discord.ChannelType.voice], + min_values=0, max_values=25, + placeholder='CHANNELS_PLACEHOLDER') + async def channels_menu(self, selection: discord.Interaction, selected: ChannelSelect): + # TODO: Consider XORing input + await selection.response.defer() + setting = self.get_instance(ScheduleSettings.SessionChannels) + setting.value = selected.values + await setting.write() + + async def channels_menu_refresh(self): + t = self.bot.translator.t + self.channels_menu.placeholder = t(_p( + 'ui:schedule_config|menu:channels|placeholder', + "Select Session Channels" + )) + + # Page 1 button + @button(label="PAGE1_BUTTON_PLACEHOLDER", style=ButtonStyle.grey) + async def page1_button(self, press: discord.Interaction, pressed: Button): + await press.response.defer(thinking=True, ephemeral=True) + self.page_num = 1 + await self.refresh(thinking=press) + + async def page1_button_refresh(self): + t = self.bot.translator.t + self.page1_button.label = t(_p( + 'ui:schedule_config|button:page1|label', + "Page 2" + )) + self.page1_button.disabled = (self.page_num == 1) + + # Page 3 button + @button(label="PAGE2_BUTTON_PLACEHOLDER", style=ButtonStyle.grey) + async def page2_button(self, press: discord.Interaction, pressed: Button): + await press.response.defer(thinking=True, ephemeral=True) + self.page_num = 2 + await self.refresh(thinking=press) + + async def page2_button_refresh(self): + t = self.bot.translator.t + self.page2_button.label = t(_p( + 'ui:schedule_config|button:page2|label', + "Page 3" + )) + self.page2_button.disabled = (self.page_num == 3) + + # Blacklist role selector + @select(cls=RoleSelect, min_values=0, max_values=1, placeholder="BLACKLIST_ROLE_PLACEHOLDER") + async def blacklist_role_menu(self, selection: discord.Interaction, selected: RoleSelect): + await selection.response.defer() + setting = self.get_instance(ScheduleSettings.BlacklistRole) + setting.value = selected.values[0] if selected.values else None + # TODO: Warning for insufficient permissions? + await setting.write() + + async def blacklist_role_menu_refresh(self): + t = self.bot.translator.t + self.blacklist_role_menu.placeholder = t(_p( + 'ui:schedule_config|menu:blacklist_role|placeholder', + "Select Blacklist Role" + )) + + # ----- UI Flow ----- + async def make_message(self) -> MessageArgs: + t = self.bot.translator.t + title = t(_p( + 'ui:schedule_config|embed|title', + "Scheduled Session Configuration Panel" + )) + embed = discord.Embed( + colour=discord.Colour.orange(), + title=title + ) + for setting in self.page_instances: + embed.add_field(**setting.embed_field, inline=False) + + args = MessageArgs(embed=embed) + return args + + async def refresh_components(self): + await asyncio.gather( + self.page0_button_refresh(), + self.page1_button_refresh(), + self.page2_button_refresh(), + self.edit_button_refresh(), + self.reset_button_refresh(), + self.close_button_refresh(), + ) + if self.page_num == 0: + await asyncio.gather( + self.lobby_menu_refresh(), + self.room_menu_refresh(), + self.channels_menu_refresh(), + ) + self.set_layout( + (self.page0_button, self.page1_button, self.page2_button), + (self.lobby_menu,), + (self.room_menu,), + (self.channels_menu,), + (self.edit_button, self.reset_button, self.close_button), + ) + elif self.page_num == 1: + self.set_layout( + (self.page0_button, self.page1_button, self.page2_button), + (self.edit_button, self.reset_button, self.close_button), + ) + elif self.page_num == 2: + await asyncio.gather( + self.blacklist_role_menu_refresh() + ) + self.set_layout( + (self.page0_button, self.page1_button, self.page2_button), + (self.blacklist_role_menu,), + (self.edit_button, self.reset_button, self.close_button), + ) + + +class ScheduleDashboard(DashboardSection): + section_name = _p( + 'dash:schedule|title', + "Scheduled Session Configuration ({commands[configure schedule]})" + ) + configui = ScheduleSettingUI + setting_classes = ScheduleSettingUI.setting_classes diff --git a/src/modules/statistics/data.py b/src/modules/statistics/data.py index a26ddeb2..1e939204 100644 --- a/src/modules/statistics/data.py +++ b/src/modules/statistics/data.py @@ -1,4 +1,5 @@ from typing import Optional, Iterable +import datetime as dt from enum import Enum from itertools import chain from psycopg import sql @@ -78,6 +79,38 @@ class StatsData(Registry): duration = Integer() end_time = Timestamp() + @classmethod + async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]): + query = sql.SQL( + """ + SELECT + t._guildid AS guildid, + t._userid AS userid, + t._start AS start_time, + t._end AS end_time, + study_time_between(t._guildid, t._userid, t._start, t._end) AS stime + FROM + (VALUES {}) + AS + t (_guildid, _userid, _start, _end) + """ + ).format( + sql.SQL(', ').join( + sql.SQL("({}, {}, {}, {})").format( + sql.Placeholder(), sql.Placeholder(), + sql.Placeholder(), sql.Placeholder() + ) + for _ in points + ) + ) + conn = await cls._connector.get_connection() + async with conn.cursor() as cursor: + await cursor.execute( + query, + chain(*points) + ) + return cursor.fetchall() + @classmethod async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int: conn = await cls._connector.get_connection() diff --git a/src/settings/setting_types.py b/src/settings/setting_types.py index f38cbe8c..95941680 100644 --- a/src/settings/setting_types.py +++ b/src/settings/setting_types.py @@ -291,7 +291,7 @@ class RoleSetting(InteractiveSetting[ParentID, int, Union[discord.Role, discord. role = guild.get_role(data) if role is None: role = discord.Object(id=data) - return role + return role @classmethod async def _parse_string(cls, parent_id, string: str, **kwargs): diff --git a/src/tracking/voice/cog.py b/src/tracking/voice/cog.py index 1b5f665e..30040648 100644 --- a/src/tracking/voice/cog.py +++ b/src/tracking/voice/cog.py @@ -37,6 +37,8 @@ class VoiceTrackerCog(LionCog): self.babel = babel # State + # Flag indicating whether local voice sessions have been initialised + self.initialised = asyncio.Event() self.handle_events = False self.tracking_lock = asyncio.Lock() @@ -92,6 +94,7 @@ class VoiceTrackerCog(LionCog): logger.debug("Disabling voice state event handling.") self.handle_events = False + self.initialised.clear() # Read and save the tracked voice states of all visible voice channels voice_members = {} # (guildid, userid) -> TrackedVoiceState voice_guilds = set() @@ -252,6 +255,7 @@ class VoiceTrackerCog(LionCog): for row in rows: VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)]) logger.info(f"Started {len(rows)} new voice sessions from voice channels!") + self.initialised.set() @LionCog.listener("on_voice_state_update") @log_wrap(action='Voice Track') @@ -259,7 +263,6 @@ class VoiceTrackerCog(LionCog): """ Spawns the correct tasks from members joining, leaving, and changing live state. """ - # TODO: Logging context if not self.handle_events: # Rely on initialisation to handle current state return @@ -505,7 +508,7 @@ class VoiceTrackerCog(LionCog): delay = (tomorrow - now).total_seconds() else: start_time = now - delay = 60 + delay = 20 expiry = start_time + dt.timedelta(seconds=cap) if expiry >= tomorrow: diff --git a/src/tracking/voice/session.py b/src/tracking/voice/session.py index 00b5d7ae..160dafb6 100644 --- a/src/tracking/voice/session.py +++ b/src/tracking/voice/session.py @@ -173,6 +173,7 @@ class VoiceSession: live_video=state.video, hourly_coins=self.hourly_rate ) + self.bot.dispatch('voice_session_start', self.data) self.start_task = None def schedule_expiry(self, expire_time): @@ -230,7 +231,11 @@ class VoiceSession: """ if self.activity is SessionState.ONGOING: # End the ongoing session - await self.data.close_study_session_at(self.guildid, self.userid, utc_now()) + now = utc_now() + await self.data.close_study_session_at(self.guildid, self.userid, now) + + # TODO: Something a bit saner/safer.. dispatch the finished session instead? + self.bot.dispatch('voice_session_end', self.data, now) # Rank update # TODO: Change to broadcasted event? diff --git a/src/utils/data.py b/src/utils/data.py new file mode 100644 index 00000000..2a8795b0 --- /dev/null +++ b/src/utils/data.py @@ -0,0 +1,162 @@ +""" +Some useful pre-built Conditions for data queries. +""" +from typing import Optional +from itertools import chain + +from psycopg import sql +from data.conditions import Condition, Joiner +from data.columns import ColumnExpr +from data.base import Expression +from constants import MAX_COINS + + +def MULTIVALUE_IN(columns: tuple[str, ...], *data: tuple[...]) -> Condition: + """ + Condition constructor for filtering by multiple column equalities. + + Example Usage + ------------- + Query.where(MULTIVALUE_IN(('guildid', 'userid'), (1, 2), (3, 4))) + """ + if not data: + raise ValueError("Cannot create empty multivalue condition.") + left = sql.SQL("({})").format( + sql.SQL(', ').join( + sql.Identifier(key) + for key in columns + ) + ) + right_item = sql.SQL('({})').format( + sql.SQL(', ').join( + sql.Placeholder() + for _ in columns + ) + ) + right = sql.SQL("({})").format( + sql.SQL(', ').join( + right_item + for _ in data + ) + ) + return Condition( + left, + Joiner.IN, + right, + chain(*data) + ) + + +def MEMBERS(*memberids: tuple[int, int], guild_column='guildid', user_column='userid') -> Condition: + """ + Condition constructor for filtering member tables by guild and user id simultaneously. + + Example Usage + ------------- + Query.where(MEMBERS((1234,12), (5678,34))) + """ + if not memberids: + raise ValueError("Cannot create a condition with no members") + return Condition( + sql.SQL("({guildid}, {userid})").format( + guildid=sql.Identifier(guild_column), + userid=sql.Identifier(user_column) + ), + Joiner.IN, + sql.SQL("({})").format( + sql.SQL(', ').join( + sql.SQL("({}, {})").format( + sql.Placeholder(), + sql.Placeholder() + ) for _ in memberids + ) + ), + chain(*memberids) + ) + + +def as_duration(expr: Expression) -> ColumnExpr: + """ + Convert an integer expression into a duration expression. + """ + expr_expr, expr_values = expr.as_tuple() + return ColumnExpr( + sql.SQL("({} * interval '1 second')").format(expr_expr), + expr_values + ) + + +class TemporaryTable(Expression): + """ + Create a temporary table expression to be used in From or With clauses. + + Example + ------- + ``` + tmp_table = TemporaryTable('_col1', '_col2', name='data') + tmp_table.values((1, 2), (3, 4)) + + real_table.update_where(col1=tmp_table['_col1']).set(col2=tmp_table['_col2']).from_(tmp_table) + ``` + """ + + def __init__(self, *columns: str, name: str = '_t', types: Optional[tuple[str]] = None): + self.name = name + self.columns = columns + self.types = types + if types and len(types) != len(columns): + raise ValueError("Number of types does not much number of columns!") + + self._table_columns = { + col: ColumnExpr(sql.Identifier(name, col)) + for col in columns + } + + self.values = [] + + def __getitem__(self, key) -> sql.Identifier: + return self._table_columns[key] + + def as_tuple(self): + """ + (VALUES {}) + AS + name (col1, col2) + """ + single_value = sql.SQL("({})").format(sql.SQL(", ").join(sql.Placeholder() for _ in self.columns)) + if self.types: + first_value = sql.SQL("({})").format( + sql.SQL(", ").join( + sql.SQL("{}::{}").format(sql.Placeholder(), sql.SQL(cast)) + for cast in self.types + ) + ) + else: + first_value = single_value + + value_placeholder = sql.SQL("(VALUES {})").format( + sql.SQL(", ").join( + (first_value, *(single_value for _ in self.values[1:])) + ) + ) + expr = sql.SQL("{values} AS {name} ({columns})").format( + values=value_placeholder, + name=sql.Identifier(self.name), + columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.columns) + ) + values = chain(*self.values) + return (expr, values) + + def set_values(self, *data): + self.values = data + + +def SAFECOINS(expr: Expression) -> Expression: + expr_expr, expr_values = expr.as_tuple() + return ColumnExpr( + sql.SQL("LEAST({}, {})").format( + expr_expr, + sql.Literal(MAX_COINS) + ), + expr_values + ) diff --git a/src/utils/ratelimits.py b/src/utils/ratelimits.py index a18e8bb9..a865799c 100644 --- a/src/utils/ratelimits.py +++ b/src/utils/ratelimits.py @@ -85,6 +85,7 @@ class Bucket: # Wrapped in a lock so that waiters are correctly handled in wait-order # Otherwise multiple waiters will have the same delay, # and race for the wakeup after sleep. + # Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order async with self._wait_lock: # We do this in a loop in case asyncio.sleep throws us out early, # or a synchronous request overflows the bucket while we are waiting. diff --git a/src/utils/ui/config.py b/src/utils/ui/config.py index 99e91bd5..06504655 100644 --- a/src/utils/ui/config.py +++ b/src/utils/ui/config.py @@ -52,6 +52,10 @@ class ConfigUI(LeoUI): # Instances of the settings this UI is managing self.instances = () + @property + def page_instances(self): + return self.instances + async def interaction_check(self, interaction: discord.Interaction): """ Default requirement for a Config UI is low management (i.e. manage_guild permissions). @@ -95,7 +99,7 @@ class ConfigUI(LeoUI): Errors should raise instances of `UserInputError`, and will be caught for retry. """ t = ctx_translator.get().t - instances = self.instances + instances = self.page_instances items = [setting.input_field for setting in instances] # Filter out settings which don't have input fields items = [item for item in items if item] @@ -174,7 +178,7 @@ class ConfigUI(LeoUI): """ await press.response.defer() - for instance in self.instances: + for instance in self.page_instances: instance.data = None await instance.write()