rewrite: Message Tracker.
This commit is contained in:
361
src/tracking/text/cog.py
Normal file
361
src/tracking/text/cog.py
Normal file
@@ -0,0 +1,361 @@
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
import time
|
||||
import datetime as dt
|
||||
from collections import defaultdict
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
from meta import LionBot, LionCog, LionContext, conf
|
||||
from meta.errors import UserInputError
|
||||
from meta.logger import log_wrap, logging_context
|
||||
from meta.sharding import THIS_SHARD
|
||||
from meta.app import appname
|
||||
from utils.lib import utc_now, error_embed
|
||||
|
||||
from wards import low_management
|
||||
from . import babel, logger
|
||||
from .data import TextTrackerData
|
||||
|
||||
from .session import TextSession
|
||||
from .settings import TextTrackerSettings, TextTrackerGlobalSettings
|
||||
from .ui import TextTrackerConfigUI
|
||||
|
||||
|
||||
_p = babel._p
|
||||
|
||||
|
||||
class TextTrackerCog(LionCog):
|
||||
"""
|
||||
LionCog module controlling and configuring the text tracking system.
|
||||
"""
|
||||
# Maximum number of completed sessions to batch before processing
|
||||
batchsize = conf.text_tracker.getint('batchsize')
|
||||
|
||||
# Maximum time to processing for a completed session
|
||||
batchtime = conf.text_tracker.getint('batchtime')
|
||||
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = bot.db.load_registry(TextTrackerData())
|
||||
self.settings = TextTrackerSettings()
|
||||
self.global_settings = TextTrackerGlobalSettings()
|
||||
self.babel = babel
|
||||
|
||||
self.sessionq = asyncio.Queue(maxsize=0)
|
||||
|
||||
# Map of ongoing text sessions
|
||||
# guildid -> (userid -> TextSession)
|
||||
self.ongoing = defaultdict(dict)
|
||||
|
||||
self._consumer_task = None
|
||||
|
||||
self.untracked_channels = self.settings.UntrackedTextChannels._cache
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
self.bot.core.guild_config.register_model_setting(self.settings.XPPerPeriod)
|
||||
self.bot.core.guild_config.register_model_setting(self.settings.WordXP)
|
||||
self.bot.core.guild_config.register_setting(self.settings.UntrackedTextChannels)
|
||||
|
||||
self.global_xp_per_period = await self.global_settings.XPPerPeriod.get(appname)
|
||||
self.global_word_xp = await self.global_settings.WordXP.get(appname)
|
||||
|
||||
leo_setting_cog = self.bot.get_cog('LeoSettings')
|
||||
leo_setting_cog.bot_setting_groups.append(self.global_settings)
|
||||
self.crossload_group(self.leo_configure_group, leo_setting_cog.leo_configure_group)
|
||||
|
||||
# Update the untracked text channel cache
|
||||
await self.settings.UntrackedTextChannels.setup(self.bot)
|
||||
|
||||
configcog = self.bot.get_cog('ConfigCog')
|
||||
if configcog is None:
|
||||
logger.critical(
|
||||
"Attempting to load the TextTrackerCog before ConfigCog! Failed to crossload configuration group."
|
||||
)
|
||||
else:
|
||||
self.crossload_group(self.configure_group, configcog.configure_group)
|
||||
|
||||
if self.bot.is_ready():
|
||||
await self.initialise()
|
||||
|
||||
async def cog_unload(self):
|
||||
if self._consumer_task is not None:
|
||||
self._consumer_task.cancel()
|
||||
|
||||
@log_wrap(stack=['Text Sessions', 'Finished'])
|
||||
async def session_handler(self, session: TextSession):
|
||||
"""
|
||||
Callback used to process a completed session.
|
||||
|
||||
Places the session into the completed queue and removes it from the session cache.
|
||||
"""
|
||||
cached = self.ongoing[session.guildid].pop(session.userid, None)
|
||||
if cached is not session:
|
||||
raise ValueError("Sync error, completed session does not match cached session!")
|
||||
logger.debug(
|
||||
"Ending text session: {session!r}".format(
|
||||
session=session
|
||||
)
|
||||
)
|
||||
self.sessionq.put_nowait(session)
|
||||
|
||||
@log_wrap(stack=['Text Sessions', 'Message Event'])
|
||||
async def _session_consumer(self):
|
||||
"""
|
||||
Process completed sessions in batches of length `batchsize`.
|
||||
"""
|
||||
# Number of sessions in the batch
|
||||
counter = 0
|
||||
batch = []
|
||||
last_time = time.monotonic()
|
||||
|
||||
closing = False
|
||||
while not closing:
|
||||
try:
|
||||
session = await self.sessionq.get()
|
||||
batch.append(session)
|
||||
counter += 1
|
||||
except asyncio.CancelledError:
|
||||
# Attempt to process the rest of the batch, then close
|
||||
closing = True
|
||||
|
||||
if counter >= self.batchsize or time.monotonic() - last_time > self.batchtime or closing:
|
||||
if batch:
|
||||
try:
|
||||
await self._process_batch(batch)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Unknown exception processing batch of text sessions! Discarding and continuing."
|
||||
)
|
||||
batch = []
|
||||
counter = 0
|
||||
last_time = time.monotonic()
|
||||
|
||||
async def _process_batch(self, batch):
|
||||
"""
|
||||
Process a batch of completed text sessions.
|
||||
|
||||
Handles economy calculations.
|
||||
"""
|
||||
if not batch:
|
||||
raise ValueError("Cannot process empty batch!")
|
||||
|
||||
logger.info(
|
||||
f"Saving batch of {len(batch)} completed text sessions."
|
||||
)
|
||||
|
||||
# Batch-fetch lguilds
|
||||
lguilds = await self.bot.core.lions.fetch_guilds(*(session.guildid for session in batch))
|
||||
|
||||
# Build data
|
||||
rows = []
|
||||
for sess in batch:
|
||||
# TODO: XP and coin calculations from settings
|
||||
# Note that XP is calculated here rather than directly through the DB
|
||||
# to support both XP and economy dynamic bonuses.
|
||||
|
||||
globalxp = (
|
||||
sess.total_periods * self.global_xp_per_period.value
|
||||
+ self.global_word_xp.value * sess.total_words / 100
|
||||
)
|
||||
|
||||
lguild = lguilds[sess.guildid]
|
||||
periodxp = lguild.config.get('xp_per_period').value
|
||||
wordxp = lguild.config.get('word_xp').value
|
||||
xpcoins = lguild.config.get('coins_per_xp').value
|
||||
guildxp = (
|
||||
sess.total_periods * periodxp
|
||||
+ wordxp * sess.total_words / 100
|
||||
)
|
||||
coins = xpcoins * guildxp / 100
|
||||
rows.append((
|
||||
sess.guildid, sess.userid,
|
||||
sess.start_time, sess.duration,
|
||||
sess.total_messages, sess.total_words, sess.total_periods,
|
||||
int(guildxp), int(globalxp),
|
||||
int(coins)
|
||||
))
|
||||
|
||||
# Submit to batch data handler
|
||||
# TODO: error handling
|
||||
await self.data.TextSessions.end_sessions(self.bot.db, *rows)
|
||||
rank_cog = self.bot.get_cog('RankCog')
|
||||
if rank_cog:
|
||||
await rank_cog.on_message_session_complete(
|
||||
*((rows[0], rows[1], rows[4], rows[7]) for rows in rows)
|
||||
)
|
||||
|
||||
@LionCog.listener('on_ready')
|
||||
@log_wrap(action='Init Text Sessions')
|
||||
async def initialise(self):
|
||||
"""
|
||||
Launch the session consumer.
|
||||
"""
|
||||
if self._consumer_task and not self._consumer_task.cancelled():
|
||||
self._consumer_task.cancel()
|
||||
self._consumer_task = asyncio.create_task(self._session_consumer())
|
||||
logger.info("Launched text session consumer.")
|
||||
|
||||
@LionCog.listener('on_message')
|
||||
@log_wrap(stack=['Text Sessions', 'Message Event'])
|
||||
async def text_message_handler(self, message):
|
||||
"""
|
||||
Message event handler for the text session tracker.
|
||||
|
||||
Process the handled message through a text session,
|
||||
creating it if required.
|
||||
"""
|
||||
# Initial wards
|
||||
if message.author.bot:
|
||||
return
|
||||
if not message.guild:
|
||||
return
|
||||
# TODO: Blacklisted ward
|
||||
|
||||
guildid = message.guild.id
|
||||
channel = message.channel
|
||||
# Untracked channel ward
|
||||
untracked = self.untracked_channels.get(guildid, [])
|
||||
if channel.id in untracked or (channel.category_id and channel.category_id in untracked):
|
||||
return
|
||||
|
||||
# Identify whether a session already exists for this member
|
||||
guild_sessions = self.ongoing[guildid]
|
||||
if (session := guild_sessions.get(message.author.id, None)) is None:
|
||||
with logging_context(context=f"mid: {message.id}"):
|
||||
session = TextSession.from_message(message)
|
||||
session.on_finish(self.session_handler)
|
||||
guild_sessions[message.author.id] = session
|
||||
logger.debug(
|
||||
"Launched new text session: {session!r}".format(
|
||||
session=session
|
||||
)
|
||||
)
|
||||
session.process(message)
|
||||
|
||||
# -------- Configuration Commands --------
|
||||
@LionCog.placeholder_group
|
||||
@cmds.hybrid_group('configure', with_app_command=False)
|
||||
async def configure_group(self, ctx: LionContext):
|
||||
# Placeholder group method, not used
|
||||
pass
|
||||
|
||||
@configure_group.command(
|
||||
name=_p('cmd:configure_message_exp', "message_exp"),
|
||||
description=_p(
|
||||
'cmd:configure_message_exp|desc',
|
||||
"Configure Message Tracking & Experience"
|
||||
)
|
||||
)
|
||||
@appcmds.rename(
|
||||
xp_per_period=TextTrackerSettings.XPPerPeriod._display_name,
|
||||
word_xp=TextTrackerSettings.WordXP._display_name,
|
||||
)
|
||||
@appcmds.describe(
|
||||
xp_per_period=TextTrackerSettings.XPPerPeriod._desc,
|
||||
word_xp=TextTrackerSettings.WordXP._desc,
|
||||
)
|
||||
@cmds.check(low_management)
|
||||
async def configure_text_tracking_cmd(self, ctx: LionContext,
|
||||
xp_per_period: Optional[appcmds.Range[int, 0, 2**15]] = None,
|
||||
word_xp: Optional[appcmds.Range[int, 0, 2**15]] = None):
|
||||
"""
|
||||
Guild configuration command to view and configure the text tracker settings.
|
||||
"""
|
||||
# Standard type checking guards
|
||||
if not ctx.guild:
|
||||
return
|
||||
if not ctx.interaction:
|
||||
return
|
||||
|
||||
# Retrieve and initialise settings
|
||||
setting_xp_period = ctx.lguild.config.get('xp_per_period')
|
||||
setting_word_xp = ctx.lguild.config.get('word_xp')
|
||||
|
||||
modified = []
|
||||
if xp_per_period is not None and setting_xp_period._data != xp_per_period:
|
||||
setting_xp_period.data = xp_per_period
|
||||
await setting_xp_period.write()
|
||||
modified.append(setting_xp_period)
|
||||
if word_xp is not None and setting_word_xp._data != word_xp:
|
||||
setting_word_xp.data = word_xp
|
||||
await setting_word_xp.write()
|
||||
modified.append(setting_word_xp)
|
||||
|
||||
# Send update ack if required
|
||||
if modified:
|
||||
desc = '\n'.join(f"{conf.emojis.tick} {setting.update_message}" for setting in modified)
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.green(),
|
||||
description=desc
|
||||
)
|
||||
)
|
||||
|
||||
if ctx.channel.id not in TextTrackerConfigUI._listening or not modified:
|
||||
# Display setting group UI
|
||||
configui = TextTrackerConfigUI(self.bot, ctx.guild.id, ctx.channel.id)
|
||||
await configui.run(ctx.interaction)
|
||||
await configui.wait()
|
||||
|
||||
# -------- Global Configuration Commands --------
|
||||
@LionCog.placeholder_group
|
||||
@cmds.hybrid_group('leo_configure', with_app_command=False)
|
||||
async def leo_configure_group(self, ctx: LionContext):
|
||||
# Placeholder group method, not used
|
||||
pass
|
||||
|
||||
@leo_configure_group.command(
|
||||
name=_p('cmd:leo_configure_exp_rates', "experience_rates"),
|
||||
description=_p(
|
||||
'cmd:leo_configure_exp_rates|desc',
|
||||
"Global experience rate configuration"
|
||||
)
|
||||
)
|
||||
@appcmds.rename(
|
||||
xp_per_period=TextTrackerGlobalSettings.XPPerPeriod._display_name,
|
||||
word_xp=TextTrackerGlobalSettings.WordXP._display_name,
|
||||
)
|
||||
@appcmds.describe(
|
||||
xp_per_period=TextTrackerGlobalSettings.XPPerPeriod._desc,
|
||||
word_xp=TextTrackerGlobalSettings.WordXP._desc,
|
||||
)
|
||||
async def leo_configure_text_tracking_cmd(self, ctx: LionContext,
|
||||
xp_per_period: Optional[appcmds.Range[int, 0, 2**15]] = None,
|
||||
word_xp: Optional[appcmds.Range[int, 0, 2**15]] = None):
|
||||
"""
|
||||
Global configuration panel for text tracking global XP.
|
||||
"""
|
||||
setting_xp_period = self.global_xp_per_period
|
||||
setting_word_xp = self.global_word_xp
|
||||
|
||||
modified = []
|
||||
if word_xp is not None and word_xp != setting_word_xp._data:
|
||||
setting_word_xp.value = word_xp
|
||||
await setting_word_xp.write()
|
||||
modified.append(setting_word_xp)
|
||||
if xp_per_period is not None and xp_per_period != setting_xp_period._data:
|
||||
setting_xp_period.value = xp_per_period
|
||||
await setting_xp_period.write()
|
||||
modified.append(setting_xp_period)
|
||||
|
||||
if modified:
|
||||
desc = '\n'.join(f"{conf.emojis.tick} {setting.update_message}" for setting in modified)
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.green(),
|
||||
description=desc
|
||||
)
|
||||
)
|
||||
else:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
title="Configure Global XP"
|
||||
)
|
||||
embed.add_field(**setting_xp_period.embed_field, inline=False)
|
||||
embed.add_field(**setting_word_xp.embed_field, inline=False)
|
||||
await ctx.reply(embed=embed)
|
||||
Reference in New Issue
Block a user