diff --git a/src/tracking/text/__init__.py b/src/tracking/text/__init__.py new file mode 100644 index 00000000..ffae887d --- /dev/null +++ b/src/tracking/text/__init__.py @@ -0,0 +1,11 @@ +import logging +from babel.translator import LocalBabel + +logger = logging.getLogger(__name__) +babel = LocalBabel('text-tracker') + + +async def setup(bot): + from .cog import TextTrackerCog + + await bot.add_cog(TextTrackerCog(bot)) diff --git a/src/tracking/text/cog.py b/src/tracking/text/cog.py new file mode 100644 index 00000000..e0b39737 --- /dev/null +++ b/src/tracking/text/cog.py @@ -0,0 +1,361 @@ +from typing import Optional +import asyncio +import time +import datetime as dt +from collections import defaultdict + +import discord +from discord.ext import commands as cmds +from discord import app_commands as appcmds + +from meta import LionBot, LionCog, LionContext, conf +from meta.errors import UserInputError +from meta.logger import log_wrap, logging_context +from meta.sharding import THIS_SHARD +from meta.app import appname +from utils.lib import utc_now, error_embed + +from wards import low_management +from . import babel, logger +from .data import TextTrackerData + +from .session import TextSession +from .settings import TextTrackerSettings, TextTrackerGlobalSettings +from .ui import TextTrackerConfigUI + + +_p = babel._p + + +class TextTrackerCog(LionCog): + """ + LionCog module controlling and configuring the text tracking system. + """ + # Maximum number of completed sessions to batch before processing + batchsize = conf.text_tracker.getint('batchsize') + + # Maximum time to processing for a completed session + batchtime = conf.text_tracker.getint('batchtime') + + def __init__(self, bot: LionBot): + self.bot = bot + self.data = bot.db.load_registry(TextTrackerData()) + self.settings = TextTrackerSettings() + self.global_settings = TextTrackerGlobalSettings() + self.babel = babel + + self.sessionq = asyncio.Queue(maxsize=0) + + # Map of ongoing text sessions + # guildid -> (userid -> TextSession) + self.ongoing = defaultdict(dict) + + self._consumer_task = None + + self.untracked_channels = self.settings.UntrackedTextChannels._cache + + async def cog_load(self): + await self.data.init() + + self.bot.core.guild_config.register_model_setting(self.settings.XPPerPeriod) + self.bot.core.guild_config.register_model_setting(self.settings.WordXP) + self.bot.core.guild_config.register_setting(self.settings.UntrackedTextChannels) + + self.global_xp_per_period = await self.global_settings.XPPerPeriod.get(appname) + self.global_word_xp = await self.global_settings.WordXP.get(appname) + + leo_setting_cog = self.bot.get_cog('LeoSettings') + leo_setting_cog.bot_setting_groups.append(self.global_settings) + self.crossload_group(self.leo_configure_group, leo_setting_cog.leo_configure_group) + + # Update the untracked text channel cache + await self.settings.UntrackedTextChannels.setup(self.bot) + + configcog = self.bot.get_cog('ConfigCog') + if configcog is None: + logger.critical( + "Attempting to load the TextTrackerCog before ConfigCog! Failed to crossload configuration group." + ) + else: + self.crossload_group(self.configure_group, configcog.configure_group) + + if self.bot.is_ready(): + await self.initialise() + + async def cog_unload(self): + if self._consumer_task is not None: + self._consumer_task.cancel() + + @log_wrap(stack=['Text Sessions', 'Finished']) + async def session_handler(self, session: TextSession): + """ + Callback used to process a completed session. + + Places the session into the completed queue and removes it from the session cache. + """ + cached = self.ongoing[session.guildid].pop(session.userid, None) + if cached is not session: + raise ValueError("Sync error, completed session does not match cached session!") + logger.debug( + "Ending text session: {session!r}".format( + session=session + ) + ) + self.sessionq.put_nowait(session) + + @log_wrap(stack=['Text Sessions', 'Message Event']) + async def _session_consumer(self): + """ + Process completed sessions in batches of length `batchsize`. + """ + # Number of sessions in the batch + counter = 0 + batch = [] + last_time = time.monotonic() + + closing = False + while not closing: + try: + session = await self.sessionq.get() + batch.append(session) + counter += 1 + except asyncio.CancelledError: + # Attempt to process the rest of the batch, then close + closing = True + + if counter >= self.batchsize or time.monotonic() - last_time > self.batchtime or closing: + if batch: + try: + await self._process_batch(batch) + except Exception: + logger.exception( + "Unknown exception processing batch of text sessions! Discarding and continuing." + ) + batch = [] + counter = 0 + last_time = time.monotonic() + + async def _process_batch(self, batch): + """ + Process a batch of completed text sessions. + + Handles economy calculations. + """ + if not batch: + raise ValueError("Cannot process empty batch!") + + logger.info( + f"Saving batch of {len(batch)} completed text sessions." + ) + + # Batch-fetch lguilds + lguilds = await self.bot.core.lions.fetch_guilds(*(session.guildid for session in batch)) + + # Build data + rows = [] + for sess in batch: + # TODO: XP and coin calculations from settings + # Note that XP is calculated here rather than directly through the DB + # to support both XP and economy dynamic bonuses. + + globalxp = ( + sess.total_periods * self.global_xp_per_period.value + + self.global_word_xp.value * sess.total_words / 100 + ) + + lguild = lguilds[sess.guildid] + periodxp = lguild.config.get('xp_per_period').value + wordxp = lguild.config.get('word_xp').value + xpcoins = lguild.config.get('coins_per_xp').value + guildxp = ( + sess.total_periods * periodxp + + wordxp * sess.total_words / 100 + ) + coins = xpcoins * guildxp / 100 + rows.append(( + sess.guildid, sess.userid, + sess.start_time, sess.duration, + sess.total_messages, sess.total_words, sess.total_periods, + int(guildxp), int(globalxp), + int(coins) + )) + + # Submit to batch data handler + # TODO: error handling + await self.data.TextSessions.end_sessions(self.bot.db, *rows) + rank_cog = self.bot.get_cog('RankCog') + if rank_cog: + await rank_cog.on_message_session_complete( + *((rows[0], rows[1], rows[4], rows[7]) for rows in rows) + ) + + @LionCog.listener('on_ready') + @log_wrap(action='Init Text Sessions') + async def initialise(self): + """ + Launch the session consumer. + """ + if self._consumer_task and not self._consumer_task.cancelled(): + self._consumer_task.cancel() + self._consumer_task = asyncio.create_task(self._session_consumer()) + logger.info("Launched text session consumer.") + + @LionCog.listener('on_message') + @log_wrap(stack=['Text Sessions', 'Message Event']) + async def text_message_handler(self, message): + """ + Message event handler for the text session tracker. + + Process the handled message through a text session, + creating it if required. + """ + # Initial wards + if message.author.bot: + return + if not message.guild: + return + # TODO: Blacklisted ward + + guildid = message.guild.id + channel = message.channel + # Untracked channel ward + untracked = self.untracked_channels.get(guildid, []) + if channel.id in untracked or (channel.category_id and channel.category_id in untracked): + return + + # Identify whether a session already exists for this member + guild_sessions = self.ongoing[guildid] + if (session := guild_sessions.get(message.author.id, None)) is None: + with logging_context(context=f"mid: {message.id}"): + session = TextSession.from_message(message) + session.on_finish(self.session_handler) + guild_sessions[message.author.id] = session + logger.debug( + "Launched new text session: {session!r}".format( + session=session + ) + ) + session.process(message) + + # -------- Configuration Commands -------- + @LionCog.placeholder_group + @cmds.hybrid_group('configure', with_app_command=False) + async def configure_group(self, ctx: LionContext): + # Placeholder group method, not used + pass + + @configure_group.command( + name=_p('cmd:configure_message_exp', "message_exp"), + description=_p( + 'cmd:configure_message_exp|desc', + "Configure Message Tracking & Experience" + ) + ) + @appcmds.rename( + xp_per_period=TextTrackerSettings.XPPerPeriod._display_name, + word_xp=TextTrackerSettings.WordXP._display_name, + ) + @appcmds.describe( + xp_per_period=TextTrackerSettings.XPPerPeriod._desc, + word_xp=TextTrackerSettings.WordXP._desc, + ) + @cmds.check(low_management) + async def configure_text_tracking_cmd(self, ctx: LionContext, + xp_per_period: Optional[appcmds.Range[int, 0, 2**15]] = None, + word_xp: Optional[appcmds.Range[int, 0, 2**15]] = None): + """ + Guild configuration command to view and configure the text tracker settings. + """ + # Standard type checking guards + if not ctx.guild: + return + if not ctx.interaction: + return + + # Retrieve and initialise settings + setting_xp_period = ctx.lguild.config.get('xp_per_period') + setting_word_xp = ctx.lguild.config.get('word_xp') + + modified = [] + if xp_per_period is not None and setting_xp_period._data != xp_per_period: + setting_xp_period.data = xp_per_period + await setting_xp_period.write() + modified.append(setting_xp_period) + if word_xp is not None and setting_word_xp._data != word_xp: + setting_word_xp.data = word_xp + await setting_word_xp.write() + modified.append(setting_word_xp) + + # Send update ack if required + if modified: + desc = '\n'.join(f"{conf.emojis.tick} {setting.update_message}" for setting in modified) + await ctx.reply( + embed=discord.Embed( + colour=discord.Colour.green(), + description=desc + ) + ) + + if ctx.channel.id not in TextTrackerConfigUI._listening or not modified: + # Display setting group UI + configui = TextTrackerConfigUI(self.bot, ctx.guild.id, ctx.channel.id) + await configui.run(ctx.interaction) + await configui.wait() + + # -------- Global Configuration Commands -------- + @LionCog.placeholder_group + @cmds.hybrid_group('leo_configure', with_app_command=False) + async def leo_configure_group(self, ctx: LionContext): + # Placeholder group method, not used + pass + + @leo_configure_group.command( + name=_p('cmd:leo_configure_exp_rates', "experience_rates"), + description=_p( + 'cmd:leo_configure_exp_rates|desc', + "Global experience rate configuration" + ) + ) + @appcmds.rename( + xp_per_period=TextTrackerGlobalSettings.XPPerPeriod._display_name, + word_xp=TextTrackerGlobalSettings.WordXP._display_name, + ) + @appcmds.describe( + xp_per_period=TextTrackerGlobalSettings.XPPerPeriod._desc, + word_xp=TextTrackerGlobalSettings.WordXP._desc, + ) + async def leo_configure_text_tracking_cmd(self, ctx: LionContext, + xp_per_period: Optional[appcmds.Range[int, 0, 2**15]] = None, + word_xp: Optional[appcmds.Range[int, 0, 2**15]] = None): + """ + Global configuration panel for text tracking global XP. + """ + setting_xp_period = self.global_xp_per_period + setting_word_xp = self.global_word_xp + + modified = [] + if word_xp is not None and word_xp != setting_word_xp._data: + setting_word_xp.value = word_xp + await setting_word_xp.write() + modified.append(setting_word_xp) + if xp_per_period is not None and xp_per_period != setting_xp_period._data: + setting_xp_period.value = xp_per_period + await setting_xp_period.write() + modified.append(setting_xp_period) + + if modified: + desc = '\n'.join(f"{conf.emojis.tick} {setting.update_message}" for setting in modified) + await ctx.reply( + embed=discord.Embed( + colour=discord.Colour.green(), + description=desc + ) + ) + else: + embed = discord.Embed( + colour=discord.Colour.orange(), + title="Configure Global XP" + ) + embed.add_field(**setting_xp_period.embed_field, inline=False) + embed.add_field(**setting_word_xp.embed_field, inline=False) + await ctx.reply(embed=embed) diff --git a/src/tracking/text/data.py b/src/tracking/text/data.py new file mode 100644 index 00000000..3a35f431 --- /dev/null +++ b/src/tracking/text/data.py @@ -0,0 +1,288 @@ +from itertools import chain +from psycopg import sql + + +from data import RowModel, Registry, Table +from data.columns import Integer, String, Timestamp, Bool + +from core.data import CoreData + + +class TextTrackerData(Registry): + class BotConfigText(RowModel): + """ + App configuration for text tracker XP. + + Schema + ------ + CREATE TABLE bot_config_experience_rates( + appname TEXT PRIMARY KEY REFERENCES bot_config(appname) ON DELETE CASCADE, + period_length INTEGER, + xp_per_period INTEGER, + xp_per_centiword INTEGER + ); + + """ + _tablename_ = 'bot_config_experience_rates' + _cache_ = {} + + appname = String(primary=True) + period_length = Integer() + xp_per_period = Integer() + xp_per_centiword = Integer() + + class TextSessions(RowModel): + """ + Model describing completed text chat sessions. + + Schema + ------ + CREATE TABLE text_sessions( + sessionid BIGSERIAL PRIMARY KEY, + guildid BIGINT NOT NULL, + userid BIGINT NOT NULL, + start_time TIMESTAMPTZ NOT NULL, + duration INTEGER NOT NULL, + messages INTEGER NOT NULL, + words INTEGER NOT NULL, + periods INTEGER NOT NULL, + user_expid BIGINT REFERENCES user_experience, + member_expid BIGINT REFERENCES member_experience, + end_time TIMESTAMP GENERATED ALWAYS AS + ((start_time AT TIME ZONE 'UTC') + duration * interval '1 second') + STORED, + FOREIGN KEY (guildid, userid) REFERENCES members (guildid, userid) ON DELETE CASCADE + ); + CREATE INDEX text_sessions_members ON text_sessions (guildid, userid); + CREATE INDEX text_sessions_start_time ON text_sessions (start_time); + CREATE INDEX text_sessions_end_time ON text_sessions (end_time); + """ + _tablename_ = 'text_sessions' + + sessionid = Integer(primary=True) + guildid = Integer() + userid = Integer() + start_time = Timestamp() + duration = Integer() + messages = Integer() + words = Integer() + periods = Integer() + end_time = Timestamp() + user_expid = Integer() + member_expid = Integer() + + @classmethod + async def end_sessions(cls, connector, *session_data): + query = sql.SQL(""" + WITH + data ( + _guildid, _userid, + _start_time, _duration, + _messages, _words, _periods, + _memberxp, _userxp, + _coins + ) + AS + (VALUES {}) + , transactions AS ( + INSERT INTO coin_transactions ( + guildid, actorid, + from_account, to_account, + amount, bonus, transactiontype + ) SELECT + data._guildid, 0, + NULL, data._userid, + SUM(_coins), 0, 'TEXT_SESSION' + FROM data + WHERE data._coins > 0 + GROUP BY (data._guildid, data._userid) + RETURNING guildid, to_account AS userid, amount, transactionid + ) + , member AS ( + UPDATE members + SET coins = coins + data._coins + FROM data + WHERE members.userid = data._userid AND members.guildid = data._guildid + ) + , member_exp AS ( + INSERT INTO member_experience ( + guildid, userid, + earned_at, + amount, exp_type, + transactionid + ) SELECT + data._guildid, data._userid, + MAX(data._start_time), + SUM(data._memberxp), 'TEXT_XP', + transactions.transactionid + FROM data + LEFT JOIN transactions ON + data._userid = transactions.userid AND + data._guildid = transactions.guildid + WHERE data._memberxp > 0 + GROUP BY (data._guildid, data._userid, transactions.transactionid) + RETURNING guildid, userid, member_expid + ) + , user_exp AS( + INSERT INTO user_experience ( + userid, + earned_at, + amount, exp_type + ) SELECT + data._userid, + MAX(data._start_time), + SUM(data._userxp), 'TEXT_XP' + FROM data + WHERE data._userxp > 0 + GROUP BY (data._userid) + RETURNING userid, user_expid + ) + INSERT INTO text_sessions( + guildid, userid, + start_time, duration, + messages, words, periods, + user_expid, member_expid + ) SELECT + data._guildid, data._userid, + data._start_time, data._duration, + data._messages, data._words, data._periods, + user_exp.user_expid, member_exp.member_expid + FROM data + LEFT JOIN member_exp ON data._userid = member_exp.userid AND data._guildid = member_exp.guildid + LEFT JOIN user_exp ON data._userid = user_exp.userid + """).format( + sql.SQL(', ').join( + sql.SQL("({}, {}, {}, {}, {}, {}, {}, {}, {}, {})").format( + sql.Placeholder(), sql.Placeholder(), + sql.Placeholder(), sql.Placeholder(), + sql.Placeholder(), sql.Placeholder(), sql.Placeholder(), + sql.Placeholder(), sql.Placeholder(), + sql.Placeholder(), + ) + for _ in session_data + ) + ) + # TODO: Consider asking for a *new* temporary connection here, to avoid blocking + # Or ask for a connection from the connection pool + # Transaction may take some time due to index updates + # Alternatively maybe use the "do not expect response mode" + conn = await connector.get_connection() + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain(*session_data)) + ) + return + + @classmethod + async def user_messages_between(cls, userid: int, *points): + """ + Compute messages written between the given points. + """ + blocks = zip(points, points[1:]) + query = sql.SQL( + """ + SELECT + ( + SELECT + SUM(messages) + FROM text_sessions s + WHERE + s.userid = %s + AND s.start_time >= periods._start + AND s.start_time < periods._end + ) AS period_m + FROM + (VALUES {}) + AS + periods (_start, _end) + ORDER BY periods._start + """ + ).format( + sql.SQL(', ').join( + sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] + ) + ) + conn = await cls._connector.get_connection() + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((userid,), *blocks)) + ) + return [r['period_m'] or 0 for r in await cursor.fetchall()] + + @classmethod + async def member_messages_between(cls, guildid: int, userid: int, *points): + """ + Compute messages written between the given points. + """ + blocks = zip(points, points[1:]) + query = sql.SQL( + """ + SELECT + ( + SELECT + SUM(messages) + FROM text_sessions s + WHERE + s.userid = %s + AND s.guildid = %s + AND s.start_time >= periods._start + AND s.start_time < periods._end + ) AS period_m + FROM + (VALUES {}) + AS + periods (_start, _end) + ORDER BY periods._start + """ + ).format( + sql.SQL(', ').join( + sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] + ) + ) + conn = await cls._connector.get_connection() + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((userid, guildid), *blocks)) + ) + return [r['period_m'] or 0 for r in await cursor.fetchall()] + + @classmethod + async def member_messages_since(cls, guildid: int, userid: int, *points): + """ + Compute messages written between the given points. + """ + query = sql.SQL( + """ + SELECT + ( + SELECT + SUM(messages) + FROM text_sessions s + WHERE + s.userid = %s + AND s.guildid = %s + AND s.start_time >= t._start + ) AS messages + FROM + (VALUES {}) + AS + t (_start) + ORDER BY t._start + """ + ).format( + sql.SQL(', ').join( + sql.SQL("({})").format(sql.Placeholder()) for _ in points + ) + ) + conn = await cls._connector.get_connection() + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((userid, guildid), points)) + ) + return [r['messages'] or 0 for r in await cursor.fetchall()] + + untracked_channels = Table('untracked_text_channels') diff --git a/src/tracking/text/session.py b/src/tracking/text/session.py new file mode 100644 index 00000000..6f3f16ea --- /dev/null +++ b/src/tracking/text/session.py @@ -0,0 +1,182 @@ +from typing import Optional +import asyncio +import datetime as dt + +import discord + +from utils.lib import utc_now + + +class TextSession: + """ + Represents an ongoing text session for a single member. + + Attributes + ---------- + userid + guildid + start_time + total_messages + total_words + total_periods + this_period_start + this_period_messages + this_period_words + timeout + """ + __slots__ = ( + 'userid', 'guildid', + 'start_time', 'end_time', + 'total_messages', 'total_words', 'total_periods', + 'this_period_start', 'this_period_messages', 'this_period_words', + 'last_message_at', 'timeout_task', + 'finish_callback', 'finish_task', 'finished', 'finished_at', + ) + + # Length of a single period + # period_length = 5 * 60 + period_length = 10 + timeout_length = 2 * period_length + + # Maximum length of a session + # session_length = 60 * 60 + session_length = 120 + + def __init__(self, userid, guildid, start_time): + self.userid = userid + self.guildid = guildid + + self.start_time = start_time + self.end_time = start_time + dt.timedelta(seconds=self.session_length) + + self.total_messages = 0 + self.total_words = 0 + self.total_periods = 0 + + self.this_period_start = start_time + self.this_period_messages = 0 + self.this_period_words = 0 + + self.last_message_at = None + self.timeout_task = None + + self.finish_callback = None + self.finish_task = None + self.finished = asyncio.Event() + self.finished_at = None + + @property + def duration(self) -> int: + if self.start_time is None: + raise ValueError("Cannot take duration of uninitialised session!") + + end = self.finished_at or utc_now() + return int((end - self.start_time).total_seconds()) + + def __repr__(self): + return ( + "(" + "{self.__class__.__name__}: " + "userid={self.userid}, " + "guildid={self.guildid}, " + "start_time={self.start_time}, " + "end_time={self.end_time}, " + "total_messages={self.total_messages}, " + "total_words={self.total_words}, " + "total_periods={self.total_periods}, " + "last_message_at={self.last_message_at}, " + "finished_at={self.finished_at}" + ")" + ).format(self=self) + + @classmethod + def from_message(cls, message: discord.Message): + """ + Instantiate a new TextSession from an initial discord message. + + Does not process the given message. + """ + if not message.guild: + raise ValueError("Cannot initialise from message outside of Guild context!") + self = cls(message.author.id, message.guild.id, message.created_at) + return self + + def process(self, message: discord.Message): + """ + Process a message into the session. + """ + if (message.author.id != self.userid) or (message.guild.id != self.guildid): + raise ValueError("Invalid attempt to process message from a different member!") + + # Identify if we need to start a new period + tdiff = (message.created_at - self.this_period_start).total_seconds() + if self.this_period_start is not None and tdiff < self.period_length: + self.this_period_messages += 1 + self.this_period_words += len(message.content.split()) + else: + self.roll_period() + self.this_period_start = message.created_at + self.this_period_messages = 1 + self.this_period_words = len(message.content.split()) + self.last_message_at = message.created_at + + # Update the session expiry + self._reschedule_timeout(self.last_message_at + dt.timedelta(seconds=self.timeout_length)) + + def roll_period(self): + """ + Add pending stats from the current period, and start a new period. + """ + if self.this_period_messages: + self.total_messages += self.this_period_messages + self.total_words += self.this_period_words + self.total_periods += 1 + self.this_period_start = None + + async def finish(self): + """ + Finalise the session and set the finished event. Idempotent. + + Also calls the registered finish callback, if set. + """ + if self.finished.is_set(): + return + + self.roll_period() + self.finished_at = self.last_message_at or utc_now() + + self.finished.set() + if self.finish_callback: + await self.finish_callback(self) + + async def cancel(self): + """ + Cancel this session. + + Does not execute the finish_callback. + """ + ... + + def on_finish(self, callback): + """ + Register a callback coroutine to be executed when the session finishes. + """ + self.finish_callback = callback + + async def _timeout(self, diff): + if diff > 0: + await asyncio.sleep(diff) + await asyncio.shield(self.finish()) + + def _reschedule_timeout(self, target_time): + """ + Schedule the finish timeout for the given target time. + """ + if self.finished.is_set(): + return + if self.finish_task and not self.finish_task.cancelled(): + self.finish_task.cancel() + + target_time = min(self.end_time, target_time) + dist = (target_time - utc_now()).total_seconds() + self.finish_task = asyncio.create_task(self._timeout(dist)) diff --git a/src/tracking/text/settings.py b/src/tracking/text/settings.py new file mode 100644 index 00000000..8379a5d2 --- /dev/null +++ b/src/tracking/text/settings.py @@ -0,0 +1,180 @@ +from typing import Optional +import asyncio +from collections import defaultdict + +from settings.groups import SettingGroup +from settings.data import ModelData, ListData +from settings.setting_types import ChannelListSetting, IntegerSetting + +from meta.config import conf +from meta.sharding import THIS_SHARD +from meta.logger import log_wrap +from core.data import CoreData +from babel.translator import ctx_translator + +from . import babel, logger +from .data import TextTrackerData + +_p = babel._p + + +class TextTrackerSettings(SettingGroup): + """ + Guild settings: + xp per period (guild_config.period_xp) + additional xp per hundred words (guild_config.word_xp) + coins per hundred xp (guild_config.xp_coins) + untracked channels (untracked_text_channels(channelid PK, guildid FK)) + """ + class XPPerPeriod(ModelData, IntegerSetting): + setting_id = 'xp_per_period' + + _display_name = _p('guildset:xp_per_period', "xp_per_5min") + _desc = _p( + 'guildset:xp_per_period|desc', + "How much XP members will be given every 5 minute period they are active." + ) + _long_desc = _p( + 'guildset:xp_per_period|long_desc', + "Amount of message XP to give members for each 5 minute period in which they are active (send a message). " + "Note that this XP is only given *once* per period." + ) + _default = 101 # TODO: Make a dynamic default based on the global setting? + + _model = CoreData.Guild + _column = CoreData.Guild.xp_per_period.name + + @property + def update_message(self): + t = ctx_translator.get().t + return t(_p( + 'guildset:xp_per_period|set_response', + "For every **5** minutes they are active (i.e. in which they send a message), " + "members will now be given **{amount}** XP." + )).format(amount=self.value) + + class WordXP(ModelData, IntegerSetting): + setting_id = 'word_xp' + + _display_name = _p('guildset:word_xp', "xp_per_100words") + _desc = _p( + 'guildset:word_xp|desc', + "How much XP members will be given per hundred words they write." + ) + _long_desc = _p( + 'guildset:word_xp|long_desc', + "Amount of message XP to be given (additionally to the XP per period) for each hundred words. " + "Useful for rewarding communication." + ) + _default = 50 + + _model = CoreData.Guild + _column = CoreData.Guild.xp_per_centiword.name + + @property + def update_message(self): + t = ctx_translator.get().t + return t(_p( + 'guildset:word_xp|set_response', + "For every **100** words they send, members will now be rewarded an additional **{amount}** XP." + )).format(amount=self.value) + + class UntrackedTextChannels(ListData, ChannelListSetting): + setting_id = 'untracked_text_channels' + + _display_name = _p('guildset:untracked_text_channels', "untracked_text_channels") + _desc = _p( + 'guildset:untracked_text_channels|desc', + "Channels in which Message XP will not be given." + ) + _long_desc = _p( + 'guildset:untracked_text_channels|long_desc', + "Messages sent in these channels will not count towards a member's message XP. " + "If a category is selected, then all channels under the category will also be untracked." + ) + + _default = None + _table_interface = TextTrackerData.untracked_channels + _id_column = 'guildid' + _data_column = 'channelid' + _order_column = 'channelid' + + _cache = {} + + @classmethod + @log_wrap(action='Cache Untracked Text Channels') + async def setup(cls, bot): + """ + Pre-load untracked text channels for every guild on the current shard. + """ + data: TextTrackerData = bot.db.registries['TextTrackerData'] + # TODO: Filter by joining on guild_config with last_left = NULL + # Otherwise we are also caching all the guilds we left + rows = await data.untracked_channels.select_where(THIS_SHARD) + new_cache = defaultdict(list) + count = 0 + for row in rows: + new_cache[row['guildid']].append(row['channelid']) + count += 1 + cls._cache.clear() + cls._cache.update(new_cache) + logger.info(f"Loaded {count} untracked text channels on this shard.") + + +class TextTrackerGlobalSettings(SettingGroup): + """ + Configure global XP rates for the text tracker. + """ + class XPPerPeriod(ModelData, IntegerSetting): + setting_id = 'xp_per_period' + + _display_name = _p('botset:xp_per_period', "xp_per_5min") + _desc = _p( + 'botset:xp_per_period|desc', + "How much global XP members will be given every 5 minute period they are active." + ) + _long_desc = _p( + 'botset:xp_per_period|long_desc', + "Amount of global message XP to give members " + "for each 5 minute period in which they are active (send a message). " + "Note that this XP is only given *once* per period." + ) + _default = 101 + + _model = TextTrackerData.BotConfigText + _column = TextTrackerData.BotConfigText.xp_per_period.name + + @property + def update_message(self): + t = ctx_translator.get().t + return t(_p( + 'leoset:xp_per_period|set_response', + "For every **5** minutes they are active (i.e. in which they send a message), " + "all users will now be given **{amount}** global XP." + )).format(amount=self.value) + + class WordXP(ModelData, IntegerSetting): + setting_id = 'word_xp' + + _display_name = _p('botset:word_xp', "xp_per_100words") + _desc = _p( + 'botset:word_xp|desc', + "How much global XP members will be given per hundred words they write." + ) + _long_desc = _p( + 'botset:word_xp|long_desc', + "Amount of global message XP to be given (additionally to the XP per period) for each hundred words. " + "Useful for rewarding communication." + ) + _default = 50 + + _model = TextTrackerData.BotConfigText + _column = TextTrackerData.BotConfigText.xp_per_centiword.name + + @property + def update_message(self): + t = ctx_translator.get().t + return t(_p( + 'leoset:word_xp|set_response', + "For every **100** words they send, users will now be rewarded an additional **{amount}** global XP." + )).format(amount=self.value) diff --git a/src/tracking/text/ui.py b/src/tracking/text/ui.py new file mode 100644 index 00000000..5d4ad618 --- /dev/null +++ b/src/tracking/text/ui.py @@ -0,0 +1,92 @@ +import asyncio + +import discord +from discord.ui.select import select, Select, ChannelSelect +from discord.ui.button import button, Button, ButtonStyle + +from meta import LionBot + +from utils.ui import ConfigUI, DashboardSection +from utils.lib import MessageArgs + +from .settings import TextTrackerSettings, TextTrackerGlobalSettings +from . import babel + +_p = babel._p + + +class TextTrackerConfigUI(ConfigUI): + setting_classes = ( + TextTrackerSettings.XPPerPeriod, + TextTrackerSettings.WordXP, + TextTrackerSettings.UntrackedTextChannels, + ) + + def __init__(self, bot: LionBot, + guildid: int, channelid: int, **kwargs): + self.settings = bot.get_cog('TextTrackerCog').settings + super().__init__(bot, guildid, channelid, **kwargs) + + @select( + cls=ChannelSelect, + placeholder='UNTRACKED_CHANNELS_PLACEHOLDER', + min_values=0, max_values=25 + ) + async def untracked_channels_menu(self, selection: discord.Interaction, selected): + await selection.response.defer() + setting = self.instances[3] + setting.value = selected.values + await setting.write() + + async def untracked_channels_menu_refresh(self): + t = self.bot.translator.t + self.untracked_channels_menu.placeholder = t(_p( + 'ui:text_tracker_config|menu:untracked_channels|placeholder', + "Select Untracked Channels" + )) + + async def make_message(self) -> MessageArgs: + t = self.bot.translator.t + title = t(_p( + 'ui:text_tracker_config|embed|title', + "Message Tracking Configuration Panel" + )) + embed = discord.Embed( + colour=discord.Colour.orange(), + title=title + ) + for setting in self.instances: + embed.add_field(**setting.embed_field, inline=False) + + args = MessageArgs(embed=embed) + return args + + async def reload(self): + lguild = await self.bot.core.lions.fetch_guild(self.guildid) + xp_per_period = lguild.config.get(self.settings.XPPerPeriod.setting_id) + wordxp = lguild.config.get(self.settings.WordXP.setting_id) + untracked = await self.settings.UntrackedTextChannels.get(self.guildid) + self.instances = ( + xp_per_period, wordxp, untracked + ) + + async def refresh_components(self): + await asyncio.gather( + self.edit_button_refresh(), + self.close_button_refresh(), + self.reset_button_refresh(), + self.untracked_channels_menu_refresh(), + ) + self._layout = [ + (self.untracked_channels_menu,), + (self.edit_button, self.reset_button, self.close_button) + ] + + +class TextTrackerDashboard(DashboardSection): + section_name = _p( + 'dash:text_tracking|title', + "Message XP configuration", + ) + configui = TextTrackerConfigUI + setting_classes = configui.setting_classes