Merge pull request #58 from StudyLions/rewrite

This commit is contained in:
Interitio
2023-10-08 12:34:54 +03:00
committed by GitHub
39 changed files with 703 additions and 531 deletions

View File

@@ -24,9 +24,9 @@ for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
class AnalyticsServer: class AnalyticsServer:
# TODO: Move these to the config # TODO: Move these to the config
# How often to request snapshots # How often to request snapshots
snap_period = 120 snap_period = 900
# How soon after a snapshot failure (e.g. not all shards online) to retry # How soon after a snapshot failure (e.g. not all shards online) to retry
snap_retry_period = 10 snap_retry_period = 60
def __init__(self) -> None: def __init__(self) -> None:
self.db = Database(conf.data['args']) self.db = Database(conf.data['args'])

View File

@@ -241,7 +241,7 @@ class BabelCog(LionCog):
matching = {item for item in formatted if partial in item[1] or partial in item[0]} matching = {item for item in formatted if partial in item[1] or partial in item[0]}
if matching: if matching:
choices = [ choices = [
appcmds.Choice(name=localestr, value=locale) appcmds.Choice(name=localestr[:100], value=locale)
for locale, localestr in matching for locale, localestr in matching
] ]
else: else:
@@ -250,7 +250,7 @@ class BabelCog(LionCog):
name=t(_p( name=t(_p(
'acmpl:language|no_match', 'acmpl:language|no_match',
"No supported languages matching {partial}" "No supported languages matching {partial}"
)).format(partial=partial), )).format(partial=partial)[:100],
value=partial value=partial
) )
] ]

View File

@@ -1,9 +1,11 @@
import gettext from typing import Optional
import logging import logging
from contextvars import ContextVar from contextvars import ContextVar
from collections import defaultdict from collections import defaultdict
from enum import Enum from enum import Enum
import gettext
from discord.app_commands import Translator, locale_str from discord.app_commands import Translator, locale_str
from discord.enums import Locale from discord.enums import Locale
@@ -70,7 +72,8 @@ class LeoBabel(Translator):
async def unload(self): async def unload(self):
self.translators.clear() self.translators.clear()
def get_translator(self, locale, domain): def get_translator(self, locale: Optional[str], domain):
locale = locale or SOURCE_LOCALE
locale = locale.replace('-', '_') if locale else None locale = locale.replace('-', '_') if locale else None
if locale == SOURCE_LOCALE: if locale == SOURCE_LOCALE:
translator = null translator = null

View File

@@ -2,6 +2,7 @@ from typing import Optional
import datetime as dt import datetime as dt
import pytz import pytz
import discord import discord
import logging
from meta import LionBot from meta import LionBot
from utils.lib import Timezoned from utils.lib import Timezoned
@@ -13,6 +14,9 @@ from .lion_user import LionUser
from .lion_guild import LionGuild from .lion_guild import LionGuild
logger = logging.getLogger(__name__)
class MemberConfig(ModelConfig): class MemberConfig(ModelConfig):
settings = SettingDotDict() settings = SettingDotDict()
_model_settings = set() _model_settings = set()
@@ -103,12 +107,16 @@ class LionMember(Timezoned):
async def remove_role(self, role: discord.Role): async def remove_role(self, role: discord.Role):
member = await self.fetch_member() member = await self.fetch_member()
if member is not None and role in member.roles: if member is not None:
try: try:
await member.remove_roles(role) await member.remove_roles(role)
except discord.HTTPException: except discord.HTTPException as e:
# TODO: Logging, audit logging # TODO: Logging, audit logging
pass logger.warning(
"Lion role removal failed for "
f"<uid: {member.id}>, <gid: {member.guild.id}>, <rid: {role.id}>. "
f"Error: {repr(e)}",
)
else: else:
# Remove the role from persistent role storage # Remove the role from persistent role storage
cog = self.bot.get_cog('MemberAdminCog') cog = self.bot.get_cog('MemberAdminCog')

View File

@@ -46,7 +46,7 @@ class LionBot(Bot):
# self.appdata = appdata # self.appdata = appdata
self.config = config self.config = config
self.app_ipc = app_ipc self.app_ipc = app_ipc
self.core: Optional['CoreCog'] = None self.core: 'CoreCog' = None
self.translator = translator self.translator = translator
self.system_monitor = SystemMonitor() self.system_monitor = SystemMonitor()

View File

@@ -38,8 +38,9 @@ class LionTree(CommandTree):
await self.error_reply(interaction, embed) await self.error_reply(interaction, embed)
except Exception: except Exception:
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'}) logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
embed = self.bugsplat(interaction, error) if interaction.type is not InteractionType.autocomplete:
await self.error_reply(interaction, embed) embed = self.bugsplat(interaction, error)
await self.error_reply(interaction, embed)
async def error_reply(self, interaction, embed): async def error_reply(self, interaction, embed):
if not interaction.is_expired(): if not interaction.is_expired():
@@ -144,7 +145,10 @@ class LionTree(CommandTree):
raise AppCommandError( raise AppCommandError(
'This should not happen, but there is no focused element. This is a Discord bug.' 'This should not happen, but there is no focused element. This is a Discord bug.'
) )
await command._invoke_autocomplete(interaction, focused, namespace) try:
await command._invoke_autocomplete(interaction, focused, namespace)
except Exception as e:
await self.on_error(interaction, e)
return return
set_logging_context(action=f"Run {command.qualified_name}") set_logging_context(action=f"Run {command.qualified_name}")

View File

@@ -185,23 +185,28 @@ class GuildDashboard(BasePager):
# ----- UI Control ----- # ----- UI Control -----
async def reload(self, *args): async def reload(self, *args):
self._cached_pages.clear() self._cached_pages.clear()
if not self._original.is_expired(): if self._original and not self._original.is_expired():
await self.redraw() await self.redraw()
else:
await self.close()
async def refresh(self): async def refresh(self):
await super().refresh() await super().refresh()
await self.config_menu_refresh() await self.config_menu_refresh()
self._layout = [ self.set_layout(
(self.config_menu,), (self.config_menu,),
(self.prev_page_button, self.next_page_button) (self.prev_page_button, self.next_page_button)
] )
async def redraw(self, *args): async def redraw(self, *args):
await self.refresh() await self.refresh()
await self._original.edit_original_response( if self._original and not self._original.is_expired():
**self.current_page.edit_args, await self._original.edit_original_response(
view=self **self.current_page.edit_args,
) view=self
)
else:
await self.close()
async def run(self, interaction: discord.Interaction): async def run(self, interaction: discord.Interaction):
await self.refresh() await self.refresh()

View File

@@ -227,7 +227,8 @@ class MemberAdminCog(LionCog):
logger.info(f"Cleared persisting roles for guild <gid:{guild.id}> because we left the guild.") logger.info(f"Cleared persisting roles for guild <gid:{guild.id}> because we left the guild.")
@LionCog.listener('on_guildset_role_persistence') @LionCog.listener('on_guildset_role_persistence')
async def clear_stored_roles(self, guildid, data): async def clear_stored_roles(self, guildid, setting: MemberAdminSettings.RolePersistence):
data = setting.data
if data is False: if data is False:
await self.data.past_roles.delete_where(guildid=guildid) await self.data.past_roles.delete_where(guildid=guildid)
logger.info( logger.info(

View File

@@ -73,7 +73,7 @@ class TimerCog(LionCog):
launched=sum(1 for timer in timers if timer._run_task and not timer._run_task.done()), launched=sum(1 for timer in timers if timer._run_task and not timer._run_task.done()),
looping=sum(1 for timer in timers if timer._loop_task and not timer._loop_task.done()), looping=sum(1 for timer in timers if timer._loop_task and not timer._loop_task.done()),
locked=sum(1 for timer in timers if timer._lock.locked()), locked=sum(1 for timer in timers if timer._lock.locked()),
voice_locked=sum(1 for timer in timers if timer._voice_update_lock.locked()), voice_locked=sum(1 for timer in timers if timer.voice_lock.locked()),
) )
if not self.ready: if not self.ready:
level = StatusLevel.STARTING level = StatusLevel.STARTING
@@ -343,7 +343,7 @@ class TimerCog(LionCog):
@LionCog.listener('on_guildset_pomodoro_channel') @LionCog.listener('on_guildset_pomodoro_channel')
@log_wrap(action='Update Pomodoro Channels') @log_wrap(action='Update Pomodoro Channels')
async def _update_pomodoro_channels(self, guildid: int, data: Optional[int]): async def _update_pomodoro_channels(self, guildid: int, setting: TimerSettings.PomodoroChannel):
""" """
Request a send_status for all guild timers which need to move channel. Request a send_status for all guild timers which need to move channel.
""" """

View File

@@ -136,6 +136,10 @@ class Timer:
channel = self.channel channel = self.channel
return channel return channel
@property
def voice_lock(self):
return self.lguild.voice_lock
async def get_notification_webhook(self) -> Optional[discord.Webhook]: async def get_notification_webhook(self) -> Optional[discord.Webhook]:
channel = self.notification_channel channel = self.notification_channel
if channel: if channel:
@@ -477,14 +481,13 @@ class Timer:
async with self.lguild.voice_lock: async with self.lguild.voice_lock:
try: try:
if self.guild.voice_client: if self.guild.voice_client:
print("Disconnecting")
await self.guild.voice_client.disconnect(force=True) await self.guild.voice_client.disconnect(force=True)
print("Disconnected")
alert_file = focus_alert_path if stage.focused else break_alert_path alert_file = focus_alert_path if stage.focused else break_alert_path
try: try:
print("Connecting") voice_client = await asyncio.wait_for(
voice_client = await self.channel.connect(timeout=60, reconnect=False) self.channel.connect(timeout=30, reconnect=False),
print("Connected") timeout=60
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(f"Timed out while connecting to voice channel in timer {self!r}") logger.warning(f"Timed out while connecting to voice channel in timer {self!r}")
return return
@@ -511,13 +514,18 @@ class Timer:
_, pending = await asyncio.wait([sleep_task, wait_task], return_when=asyncio.FIRST_COMPLETED) _, pending = await asyncio.wait([sleep_task, wait_task], return_when=asyncio.FIRST_COMPLETED)
for task in pending: for task in pending:
task.cancel() task.cancel()
except asyncio.TimeoutError:
if self.guild and self.guild.voice_client: logger.warning(
await self.guild.voice_client.disconnect(force=True) f"Timed out while sending voice alert for timer {self!r}",
exc_info=True
)
except Exception: except Exception:
logger.exception( logger.exception(
f"Exception occurred while playing voice alert for timer {self!r}" f"Exception occurred while playing voice alert for timer {self!r}"
) )
finally:
if self.guild and self.guild.voice_client:
await self.guild.voice_client.disconnect(force=True)
def stageline(self, stage: Stage): def stageline(self, stage: Stage):
t = self.bot.translator.t t = self.bot.translator.t
@@ -772,7 +780,7 @@ class Timer:
logger.info(f"Timer {self!r} has stopped. Auto restart is {'on' if auto_restart else 'off'}") logger.info(f"Timer {self!r} has stopped. Auto restart is {'on' if auto_restart else 'off'}")
@log_wrap(action="Destroy Timer") @log_wrap(action="Destroy Timer")
async def destroy(self, reason: str = None): async def destroy(self, reason: Optional[str] = None):
""" """
Deconstructs the timer, stopping all tasks. Deconstructs the timer, stopping all tasks.
""" """

View File

@@ -6,7 +6,7 @@ from discord.ui.select import select, Select, SelectOption, RoleSelect
from discord.ui.button import button, Button, ButtonStyle from discord.ui.button import button, Button, ButtonStyle
from meta import conf, LionBot from meta import conf, LionBot
from meta.errors import ResponseTimedOut from meta.errors import ResponseTimedOut, SafeCancellation
from core.data import RankType from core.data import RankType
from data import ORDER from data import ORDER
@@ -16,7 +16,7 @@ from wards import equippable_role
from babel.translator import ctx_translator from babel.translator import ctx_translator
from .. import babel, logger from .. import babel, logger
from ..data import AnyRankData from ..data import AnyRankData, RankData
from ..utils import rank_model_from_type, format_stat_range, stat_data_to_value from ..utils import rank_model_from_type, format_stat_range, stat_data_to_value
from .editor import RankEditor from .editor import RankEditor
from .preview import RankPreviewUI from .preview import RankPreviewUI
@@ -101,6 +101,7 @@ class RankOverviewUI(MessageUI):
Refresh the current ranks, Refresh the current ranks,
ensuring that all members have the correct rank. ensuring that all members have the correct rank.
""" """
await press.response.defer(thinking=True)
async with self.cog.ranklock(self.guild.id): async with self.cog.ranklock(self.guild.id):
await self.cog.interactive_rank_refresh(press, self.guild) await self.cog.interactive_rank_refresh(press, self.guild)
@@ -156,11 +157,21 @@ class RankOverviewUI(MessageUI):
Errors if the client does not have permission to create roles. Errors if the client does not have permission to create roles.
""" """
t = self.bot.translator.t
if not self.guild.me.guild_permissions.manage_roles:
raise SafeCancellation(t(_p(
'ui:rank_overview|button:create|error:my_permissions',
"I lack the 'Manage Roles' permission required to create rank roles!"
)))
async def _create_callback(rank, submit: discord.Interaction): async def _create_callback(rank, submit: discord.Interaction):
await submit.response.send_message( await submit.response.send_message(
embed=discord.Embed( embed=discord.Embed(
colour=discord.Colour.brand_green(), colour=discord.Colour.brand_green(),
description="Rank Created!" description=t(_p(
'ui:rank_overview|button:create|success',
"Created a new rank {role}"
)).format(role=f"<@&{rank.roleid}>")
), ),
ephemeral=True ephemeral=True
) )

View File

@@ -447,7 +447,7 @@ class Reminders(LionCog):
)) ))
value = 'None' value = 'None'
choices = [ choices = [
appcmds.Choice(name=name, value=value) appcmds.Choice(name=name[:100], value=value)
] ]
else: else:
# Build list of reminder strings # Build list of reminder strings
@@ -463,7 +463,7 @@ class Reminders(LionCog):
# Build list of valid choices # Build list of valid choices
choices = [ choices = [
appcmds.Choice( appcmds.Choice(
name=string[0], name=string[0][:100],
value=f"rid:{string[1].reminderid}" value=f"rid:{string[1].reminderid}"
) )
for string in matches for string in matches
@@ -474,7 +474,7 @@ class Reminders(LionCog):
name=t(_p( name=t(_p(
'cmd:reminders_cancel|acmpl:reminder|error:no_matches', 'cmd:reminders_cancel|acmpl:reminder|error:no_matches',
"You do not have any reminders matching \"{partial}\"" "You do not have any reminders matching \"{partial}\""
)).format(partial=partial), )).format(partial=partial)[:100],
value=partial value=partial
) )
] ]
@@ -562,7 +562,7 @@ class Reminders(LionCog):
name=t(_p( name=t(_p(
'cmd:remindme_at|acmpl:time|error:parse', 'cmd:remindme_at|acmpl:time|error:parse',
"Cannot parse \"{partial}\" as a time. Try the format HH:MM or YYYY-MM-DD HH:MM" "Cannot parse \"{partial}\" as a time. Try the format HH:MM or YYYY-MM-DD HH:MM"
)).format(partial=partial), )).format(partial=partial)[:100],
value=partial value=partial
) )
return [choice] return [choice]

View File

@@ -14,6 +14,7 @@ from meta import LionCog, LionBot, LionContext
from meta.logger import log_wrap from meta.logger import log_wrap
from meta.errors import ResponseTimedOut, UserInputError, UserCancelled, SafeCancellation from meta.errors import ResponseTimedOut, UserInputError, UserCancelled, SafeCancellation
from meta.sharding import THIS_SHARD from meta.sharding import THIS_SHARD
from meta.monitor import ComponentMonitor, ComponentStatus, StatusLevel
from utils.lib import utc_now, error_embed from utils.lib import utc_now, error_embed
from utils.ui import Confirm, ChoicedEnum, Transformed, AButton, AsComponents from utils.ui import Confirm, ChoicedEnum, Transformed, AButton, AsComponents
from utils.transformers import DurationTransformer from utils.transformers import DurationTransformer
@@ -142,6 +143,9 @@ class RoleMenuCog(LionCog):
def __init__(self, bot: LionBot): def __init__(self, bot: LionBot):
self.bot = bot self.bot = bot
self.data = bot.db.load_registry(RoleMenuData()) self.data = bot.db.load_registry(RoleMenuData())
self.monitor = ComponentMonitor('RoleMenus', self._monitor)
self.ready = asyncio.Event()
# Menu caches # Menu caches
self.live_menus = RoleMenu.attached_menus # guildid -> messageid -> menuid self.live_menus = RoleMenu.attached_menus # guildid -> messageid -> menuid
@@ -149,11 +153,42 @@ class RoleMenuCog(LionCog):
# Expiry manage # Expiry manage
self.expiry_monitor = ExpiryMonitor(executor=self._expire) self.expiry_monitor = ExpiryMonitor(executor=self._expire)
async def _monitor(self):
state = (
"<"
"RoleMenus"
" ready={ready}"
" cached={cached}"
" views={views}"
" live={live}"
" expiry={expiry}"
">"
)
data = dict(
ready=self.ready.is_set(),
live=sum(len(gmenus) for gmenus in self.live_menus.values()),
expiry=repr(self.expiry_monitor),
cached=len(RoleMenu._menus),
views=len(RoleMenu.menu_views),
)
if not self.ready.is_set():
level = StatusLevel.STARTING
info = f"(STARTING) Not initialised. {state}"
elif not self.expiry_monitor._monitor_task:
level = StatusLevel.ERRORED
info = f"(ERRORED) Expiry monitor not running. {state}"
else:
level = StatusLevel.OKAY
info = f"(OK) RoleMenu loaded and listening. {state}"
return ComponentStatus(level, info, info, data)
# ----- Initialisation ----- # ----- Initialisation -----
async def cog_load(self): async def cog_load(self):
self.bot.system_monitor.add_component(self.monitor)
await self.data.init() await self.data.init()
self.bot.tree.add_command(rolemenu_ctxcmd) self.bot.tree.add_command(rolemenu_ctxcmd, override=True)
if self.bot.is_ready(): if self.bot.is_ready():
await self.initialise() await self.initialise()
@@ -164,17 +199,28 @@ class RoleMenuCog(LionCog):
self.live_menus.clear() self.live_menus.clear()
if self.expiry_monitor._monitor_task: if self.expiry_monitor._monitor_task:
self.expiry_monitor._monitor_task.cancel() self.expiry_monitor._monitor_task.cancel()
self.bot.tree.remove_command(rolemenu_ctxcmd)
@LionCog.listener('on_ready') @LionCog.listener('on_ready')
@log_wrap(action="Initialise Role Menus") @log_wrap(action="Initialise Role Menus")
async def initialise(self): async def initialise(self):
self.ready.clear()
# Clean up live menu tasks
for menu in list(RoleMenu._menus.values()):
menu.detach()
self.live_menus.clear()
if self.expiry_monitor._monitor_task:
self.expiry_monitor._monitor_task.cancel()
# Start monitor
self.expiry_monitor = ExpiryMonitor(executor=self._expire) self.expiry_monitor = ExpiryMonitor(executor=self._expire)
self.expiry_monitor.start() self.expiry_monitor.start()
# Load guilds
guildids = [guild.id for guild in self.bot.guilds] guildids = [guild.id for guild in self.bot.guilds]
if guildids: if guildids:
await self._initialise_guilds(*guildids) await self._initialise_guilds(*guildids)
self.ready.set()
async def _initialise_guilds(self, *guildids): async def _initialise_guilds(self, *guildids):
""" """
@@ -262,7 +308,7 @@ class RoleMenuCog(LionCog):
If the bot is no longer in the server, ignores the expiry. If the bot is no longer in the server, ignores the expiry.
If the member is no longer in the server, removes the role from persisted roles, if applicable. If the member is no longer in the server, removes the role from persisted roles, if applicable.
""" """
logger.debug(f"Expiring RoleMenu equipped role {equipid}") logger.info(f"Expiring RoleMenu equipped role {equipid}")
rows = await self.data.RoleMenuHistory.fetch_expiring_where(equipid=equipid) rows = await self.data.RoleMenuHistory.fetch_expiring_where(equipid=equipid)
if rows: if rows:
equip_row = rows[0] equip_row = rows[0]
@@ -277,6 +323,7 @@ class RoleMenuCog(LionCog):
await equip_row.update(removed_at=now) await equip_row.update(removed_at=now)
else: else:
# equipid is no longer valid or is not expiring # equipid is no longer valid or is not expiring
logger.info(f"RoleMenu equipped role {equipid} is no longer valid or is not expiring.")
pass pass
# ----- Private Utils ----- # ----- Private Utils -----
@@ -487,7 +534,7 @@ class RoleMenuCog(LionCog):
choice_name = menu.data.name choice_name = menu.data.name
choice_value = f"menuid:{menu.data.menuid}" choice_value = f"menuid:{menu.data.menuid}"
choices.append( choices.append(
appcmds.Choice(name=choice_name, value=choice_value) appcmds.Choice(name=choice_name[:100], value=choice_value)
) )
if not choices: if not choices:
@@ -498,7 +545,7 @@ class RoleMenuCog(LionCog):
)).format(partial=partial) )).format(partial=partial)
choice_value = partial choice_value = partial
choice = appcmds.Choice( choice = appcmds.Choice(
name=choice_name, value=choice_value name=choice_name[:100], value=choice_value
) )
choices.append(choice) choices.append(choice)
@@ -522,7 +569,7 @@ class RoleMenuCog(LionCog):
"Please select a menu first" "Please select a menu first"
)) ))
choice_value = partial choice_value = partial
choices = [appcmds.Choice(name=choice_name, value=choice_value)] choices = [appcmds.Choice(name=choice_name[:100], value=choice_value)]
else: else:
# Resolve the menu name # Resolve the menu name
menu: RoleMenu menu: RoleMenu
@@ -544,7 +591,7 @@ class RoleMenuCog(LionCog):
name=t(_p( name=t(_p(
'acmpl:menuroles|choice:invalid_menu|name', 'acmpl:menuroles|choice:invalid_menu|name',
"Menu '{name}' does not exist!" "Menu '{name}' does not exist!"
)).format(name=menu_name), )).format(name=menu_name)[:100],
value=partial value=partial
) )
choices = [choice] choices = [choice]
@@ -564,7 +611,7 @@ class RoleMenuCog(LionCog):
else: else:
name = mrole.data.label name = mrole.data.label
choice = appcmds.Choice( choice = appcmds.Choice(
name=name, name=name[:100],
value=f"<@&{mrole.data.roleid}>" value=f"<@&{mrole.data.roleid}>"
) )
choices.append(choice) choices.append(choice)
@@ -573,7 +620,7 @@ class RoleMenuCog(LionCog):
name=t(_p( name=t(_p(
'acmpl:menuroles|choice:no_matching|name', 'acmpl:menuroles|choice:no_matching|name',
"No roles in this menu matching '{partial}'" "No roles in this menu matching '{partial}'"
)).format(partial=partial), )).format(partial=partial)[:100],
value=partial value=partial
) )
return choices[:25] return choices[:25]

View File

@@ -173,14 +173,15 @@ class RoomCog(LionCog):
# Setting event handlers # Setting event handlers
@LionCog.listener('on_guildset_rooms_category') @LionCog.listener('on_guildset_rooms_category')
@log_wrap(action='Update Rooms Category') @log_wrap(action='Update Rooms Category')
async def _update_rooms_category(self, guildid: int, data: Optional[int]): async def _update_rooms_category(self, guildid: int, setting: RoomSettings.Category):
""" """
Move all active private channels to the new category. Move all active private channels to the new category.
This shouldn't affect the channel function at all. This shouldn't affect the channel function at all.
""" """
data = setting.data
guild = self.bot.get_guild(guildid) guild = self.bot.get_guild(guildid)
new_category = guild.get_channel(data) if guild else None new_category = guild.get_channel(data) if guild and data else None
if new_category: if new_category:
tasks = [] tasks = []
for room in list(self._room_cache[guildid].values()): for room in list(self._room_cache[guildid].values()):
@@ -196,10 +197,11 @@ class RoomCog(LionCog):
@LionCog.listener('on_guildset_rooms_visible') @LionCog.listener('on_guildset_rooms_visible')
@log_wrap(action='Update Rooms Visibility') @log_wrap(action='Update Rooms Visibility')
async def _update_rooms_visibility(self, guildid: int, data: bool): async def _update_rooms_visibility(self, guildid: int, setting: RoomSettings.Visible):
""" """
Update the everyone override on each room to reflect the new setting. Update the everyone override on each room to reflect the new setting.
""" """
data = setting.data
tasks = [] tasks = []
for room in list(self._room_cache[guildid].values()): for room in list(self._room_cache[guildid].values()):
if room.channel: if room.channel:

View File

@@ -904,10 +904,10 @@ class ScheduleCog(LionCog):
if not interaction.guild or not isinstance(interaction.user, discord.Member): if not interaction.guild or not isinstance(interaction.user, discord.Member):
choice = appcmds.Choice( choice = appcmds.Choice(
name=_p( name=t(_p(
'cmd:schedule|acmpl:book|error:not_in_guild', 'cmd:schedule|acmpl:book|error:not_in_guild',
"You need to be in a server to book sessions!" "You need to be in a server to book sessions!"
), ))[:100],
value='None' value='None'
) )
choices = [choice] choices = [choice]
@@ -917,10 +917,10 @@ class ScheduleCog(LionCog):
blacklist_role = (await self.settings.BlacklistRole.get(interaction.guild.id)).value blacklist_role = (await self.settings.BlacklistRole.get(interaction.guild.id)).value
if blacklist_role and blacklist_role in member.roles: if blacklist_role and blacklist_role in member.roles:
choice = appcmds.Choice( choice = appcmds.Choice(
name=_p( name=t(_p(
'cmd:schedule|acmpl:book|error:blacklisted', 'cmd:schedule|acmpl:book|error:blacklisted',
"Cannot Book -- Blacklisted" "Cannot Book -- Blacklisted"
), ))[:100],
value='None' value='None'
) )
choices = [choice] choices = [choice]
@@ -947,7 +947,7 @@ class ScheduleCog(LionCog):
) )
choices.append( choices.append(
appcmds.Choice( appcmds.Choice(
name=tzstring, value='None', name=tzstring[:100], value='None',
) )
) )
@@ -968,7 +968,7 @@ class ScheduleCog(LionCog):
if partial.lower() in name.lower(): if partial.lower() in name.lower():
choices.append( choices.append(
appcmds.Choice( appcmds.Choice(
name=name, name=name[:100],
value=str(slotid) value=str(slotid)
) )
) )
@@ -978,11 +978,11 @@ class ScheduleCog(LionCog):
name=t(_p( name=t(_p(
"cmd:schedule|acmpl:book|no_matching", "cmd:schedule|acmpl:book|no_matching",
"No bookable sessions matching '{partial}'" "No bookable sessions matching '{partial}'"
)).format(partial=partial[:25]), )).format(partial=partial[:25])[:100],
value=partial value=partial
) )
) )
return choices return choices[:25]
@schedule_cmd.autocomplete('cancel') @schedule_cmd.autocomplete('cancel')
async def schedule_cmd_cancel_acmpl(self, interaction: discord.Interaction, partial: str): async def schedule_cmd_cancel_acmpl(self, interaction: discord.Interaction, partial: str):
@@ -998,10 +998,10 @@ class ScheduleCog(LionCog):
can_cancel = list(slotid for slotid in schedule if slotid > minid) can_cancel = list(slotid for slotid in schedule if slotid > minid)
if not can_cancel: if not can_cancel:
choice = appcmds.Choice( choice = appcmds.Choice(
name=_p( name=t(_p(
'cmd:schedule|acmpl:cancel|error:empty_schedule', 'cmd:schedule|acmpl:cancel|error:empty_schedule',
"You do not have any upcoming sessions to cancel!" "You do not have any upcoming sessions to cancel!"
), ))[:100],
value='None' value='None'
) )
choices.append(choice) choices.append(choice)
@@ -1025,7 +1025,7 @@ class ScheduleCog(LionCog):
if partial.lower() in name.lower(): if partial.lower() in name.lower():
choices.append( choices.append(
appcmds.Choice( appcmds.Choice(
name=name, name=name[:100],
value=str(slotid) value=str(slotid)
) )
) )
@@ -1034,7 +1034,7 @@ class ScheduleCog(LionCog):
name=t(_p( name=t(_p(
'cmd:schedule|acmpl:cancel|error:no_matching', 'cmd:schedule|acmpl:cancel|error:no_matching',
"No cancellable sessions matching '{partial}'" "No cancellable sessions matching '{partial}'"
)).format(partial=partial[:25]), )).format(partial=partial[:25])[:100],
value='None' value='None'
) )
choices.append(choice) choices.append(choice)

View File

@@ -442,7 +442,7 @@ class ScheduledSession:
'session|notify|dm|join_line:channels', 'session|notify|dm|join_line:channels',
"Please attend your session by joining one of the following:" "Please attend your session by joining one of the following:"
)) ))
join_line = '\n'.join(join_line, *(channel.mention for channel in valid[:20])) join_line = '\n'.join((join_line, *(channel.mention for channel in valid[:20])))
if len(valid) > 20: if len(valid) > 20:
join_line += '\n...' join_line += '\n...'

View File

@@ -446,7 +446,7 @@ class ColourShopping(ShopCog):
), ),
ephemeral=True ephemeral=True
) )
await logger.warning( logger.warning(
"Unexpected Discord exception occurred while creating a colour role.", "Unexpected Discord exception occurred while creating a colour role.",
exc_info=True exc_info=True
) )
@@ -469,8 +469,13 @@ class ColourShopping(ShopCog):
# Due to the imprecise nature of Discord role ordering, this may fail. # Due to the imprecise nature of Discord role ordering, this may fail.
try: try:
role = await role.edit(position=position) role = await role.edit(position=position)
except discord.Forbidden: except discord.HTTPException as e:
position = 0 if e.code == 50013 or e.status == 403:
# Forbidden case
# But Discord sends its 'Missing Permissions' with a 400 code for position issues
position = 0
else:
raise
# Now that the role is set up, add it to data # Now that the role is set up, add it to data
item = await self.data.ShopItem.create( item = await self.data.ShopItem.create(
@@ -1090,7 +1095,7 @@ class ColourShopping(ShopCog):
for i, item in enumerate(items, start=1) for i, item in enumerate(items, start=1)
] ]
options = [option for option in options if partial.lower() in option[1].lower()] options = [option for option in options if partial.lower() in option[1].lower()]
return [appcmds.Choice(name=option[1], value=option[0]) for option in options] return [appcmds.Choice(name=option[1][:100], value=option[0]) for option in options]
class ColourStore(Store): class ColourStore(Store):

View File

@@ -122,7 +122,7 @@ class StatsData(Registry):
"SELECT study_time_between(%s, %s, %s, %s)", "SELECT study_time_between(%s, %s, %s, %s)",
(guildid, userid, _start, _end) (guildid, userid, _start, _end)
) )
return (await cursor.fetchone()[0]) or 0 return (await cursor.fetchone())[0] or 0
@classmethod @classmethod
@log_wrap(action='study_times_between') @log_wrap(action='study_times_between')
@@ -162,11 +162,11 @@ class StatsData(Registry):
"SELECT study_time_since(%s, %s, %s)", "SELECT study_time_since(%s, %s, %s)",
(guildid, userid, _start) (guildid, userid, _start)
) )
return (await cursor.fetchone()[0]) or 0 return (await cursor.fetchone())[0] or 0
@classmethod @classmethod
@log_wrap(action='study_times_since') @log_wrap(action='study_times_since')
async def study_times_since(cls, guildid: Optional[int], userid: int, *starts) -> int: async def study_times_since(cls, guildid: Optional[int], userid: int, *starts) -> list[int]:
if len(starts) < 1: if len(starts) < 1:
raise ValueError('No starting points given!') raise ValueError('No starting points given!')
@@ -251,7 +251,7 @@ class StatsData(Registry):
return leaderboard return leaderboard
@classmethod @classmethod
@log_wrap('leaderboard_all') @log_wrap(action='leaderboard_all')
async def leaderboard_all(cls, guildid: int): async def leaderboard_all(cls, guildid: int):
""" """
Return the all-time voice totals for the given guild. Return the all-time voice totals for the given guild.

View File

@@ -41,7 +41,7 @@ class StatsUI(LeoUI):
""" """
ID of guild to render stats for, or None if global. ID of guild to render stats for, or None if global.
""" """
return self.guild.id if not self._showing_global else None return self.guild.id if self.guild and not self._showing_global else None
@property @property
def userid(self) -> int: def userid(self) -> int:
@@ -67,7 +67,8 @@ class StatsUI(LeoUI):
Delete the output message and close the UI. Delete the output message and close the UI.
""" """
await press.response.defer() await press.response.defer()
await self._original.delete_original_response() if self._original and not self._original.is_expired():
await self._original.delete_original_response()
self._original = None self._original = None
await self.close() await self.close()
@@ -93,7 +94,10 @@ class StatsUI(LeoUI):
args = await self.make_message() args = await self.make_message()
if thinking is not None and not thinking.is_expired() and thinking.response.is_done(): if thinking is not None and not thinking.is_expired() and thinking.response.is_done():
asyncio.create_task(thinking.delete_original_response()) asyncio.create_task(thinking.delete_original_response())
await self._original.edit_original_response(**args.edit_args, view=self) if self._original and not self._original.is_expired():
await self._original.edit_original_response(**args.edit_args, view=self)
else:
await self.close()
async def refresh(self, thinking: Optional[discord.Interaction] = None): async def refresh(self, thinking: Optional[discord.Interaction] = None):
""" """

View File

@@ -41,6 +41,7 @@ class StatType(IntEnum):
class LeaderboardUI(StatsUI): class LeaderboardUI(StatsUI):
page_size = 10 page_size = 10
guildid: int
def __init__(self, bot, user, guild, **kwargs): def __init__(self, bot, user, guild, **kwargs):
super().__init__(bot, user, guild, **kwargs) super().__init__(bot, user, guild, **kwargs)
@@ -199,6 +200,9 @@ class LeaderboardUI(StatsUI):
mode = CardMode.TEXT mode = CardMode.TEXT
elif self.stat_type is StatType.ANKI: elif self.stat_type is StatType.ANKI:
mode = CardMode.ANKI mode = CardMode.ANKI
else:
raise ValueError
card = await get_leaderboard_card( card = await get_leaderboard_card(
self.bot, self.userid, self.guildid, self.bot, self.userid, self.guildid,
mode, mode,

View File

@@ -166,7 +166,7 @@ class ProfileUI(StatsUI):
t = self.bot.translator.t t = self.bot.translator.t
data: StatsData = self.bot.get_cog('StatsCog').data data: StatsData = self.bot.get_cog('StatsCog').data
tags = await data.ProfileTag.fetch_tags(self.guildid, self.userid) tags = await data.ProfileTag.fetch_tags(self.guild.id, self.userid)
modal = ProfileEditor() modal = ProfileEditor()
modal.editor.default = '\n'.join(tags) modal.editor.default = '\n'.join(tags)
@@ -177,7 +177,7 @@ class ProfileUI(StatsUI):
await interaction.response.defer(thinking=True, ephemeral=True) await interaction.response.defer(thinking=True, ephemeral=True)
# Set the new tags and refresh # Set the new tags and refresh
await data.ProfileTag.set_tags(self.guildid, self.userid, new_tags) await data.ProfileTag.set_tags(self.guild.id, self.userid, new_tags)
if self._original is not None: if self._original is not None:
self._profile_card = None self._profile_card = None
await self.refresh(thinking=interaction) await self.refresh(thinking=interaction)
@@ -310,7 +310,7 @@ class ProfileUI(StatsUI):
""" """
Create and render the XP and stats cards. Create and render the XP and stats cards.
""" """
card = await get_profile_card(self.bot, self.userid, self.guildid) card = await get_profile_card(self.bot, self.userid, self.guild.id)
if card: if card:
await card.render() await card.render()
self._profile_card = card self._profile_card = card

View File

@@ -329,7 +329,7 @@ class Exec(LionCog):
results = [ results = [
appcmd.Choice(name=f"No peers found matching {partial}", value=partial) appcmd.Choice(name=f"No peers found matching {partial}", value=partial)
] ]
return results return results[:25]
async_cmd.autocomplete('target')(_peer_acmpl) async_cmd.autocomplete('target')(_peer_acmpl)

View File

@@ -242,6 +242,7 @@ class PresenceCtrl(LionCog):
await self.data.init() await self.data.init()
if (leo_setting_cog := self.bot.get_cog('LeoSettings')) is not None: if (leo_setting_cog := self.bot.get_cog('LeoSettings')) is not None:
leo_setting_cog.bot_setting_groups.append(self.settings) leo_setting_cog.bot_setting_groups.append(self.settings)
self.crossload_group(self.leo_group, leo_setting_cog.leo_group)
await self.reload_presence() await self.reload_presence()
self.update_listeners() self.update_listeners()
@@ -372,7 +373,12 @@ class PresenceCtrl(LionCog):
"Unhandled exception occurred running client presence update loop. Closing loop." "Unhandled exception occurred running client presence update loop. Closing loop."
) )
@cmds.hybrid_command( @LionCog.placeholder_group
@cmds.hybrid_group('configure', with_app_command=False)
async def leo_group(self, ctx: LionContext):
...
@leo_group.command(
name="presence", name="presence",
description="Globally set the bot status and activity." description="Globally set the bot status and activity."
) )

View File

@@ -291,7 +291,7 @@ class TasklistCog(LionCog):
name=t(_p( name=t(_p(
'argtype:taskid|error:no_tasks', 'argtype:taskid|error:no_tasks',
"Tasklist empty! No matching tasks." "Tasklist empty! No matching tasks."
)), ))[:100],
value=partial value=partial
) )
] ]
@@ -319,7 +319,7 @@ class TasklistCog(LionCog):
if matching: if matching:
# If matches were found, assume user wants one of the matches # If matches were found, assume user wants one of the matches
options = [ options = [
appcmds.Choice(name=task_string, value=label) appcmds.Choice(name=task_string[:100], value=label)
for label, task_string in matching for label, task_string in matching
] ]
elif multi and partial.lower().strip() in ('-', 'all'): elif multi and partial.lower().strip() in ('-', 'all'):
@@ -328,7 +328,7 @@ class TasklistCog(LionCog):
name=t(_p( name=t(_p(
'argtype:taskid|match:all', 'argtype:taskid|match:all',
"All tasks" "All tasks"
)), ))[:100],
value='-' value='-'
) )
] ]
@@ -353,7 +353,7 @@ class TasklistCog(LionCog):
multi_name = f"{partial[:remaining-1]} {error}" multi_name = f"{partial[:remaining-1]} {error}"
multi_option = appcmds.Choice( multi_option = appcmds.Choice(
name=multi_name, name=multi_name[:100],
value=partial value=partial
) )
options = [multi_option] options = [multi_option]
@@ -371,7 +371,7 @@ class TasklistCog(LionCog):
if not matching: if not matching:
matching = [(label, task) for label, task in labels if last_split.lower() in task.lower()] matching = [(label, task) for label, task in labels if last_split.lower() in task.lower()]
options.extend( options.extend(
appcmds.Choice(name=task_string, value=label) appcmds.Choice(name=task_string[:100], value=label)
for label, task_string in matching for label, task_string in matching
) )
else: else:
@@ -380,7 +380,7 @@ class TasklistCog(LionCog):
name=t(_p( name=t(_p(
'argtype:taskid|error:no_matching', 'argtype:taskid|error:no_matching',
"No tasks matching '{partial}'!", "No tasks matching '{partial}'!",
)).format(partial=partial[:100]), )).format(partial=partial[:100])[:100],
value=partial value=partial
) )
] ]

View File

@@ -728,7 +728,7 @@ class TasklistUI(BasePager):
) )
try: try:
await press.user.send(contents, file=file, silent=True) await press.user.send(contents, file=file, silent=True)
except discord.HTTPClient: except discord.HTTPException:
fp.seek(0) fp.seek(0)
file = discord.File(fp, filename='tasklist.md') file = discord.File(fp, filename='tasklist.md')
await press.followup.send( await press.followup.send(
@@ -736,7 +736,7 @@ class TasklistUI(BasePager):
'ui:tasklist|button:save|error:dms', 'ui:tasklist|button:save|error:dms',
"Could not DM you! Do you have me blocked? Tasklist attached below." "Could not DM you! Do you have me blocked? Tasklist attached below."
)), )),
file=file file=file,
) )
else: else:
fp.seek(0) fp.seek(0)

View File

@@ -393,7 +393,7 @@ class VideoCog(LionCog):
only_warn = True only_warn = True
# Ack based on ticket created # Ack based on ticket created
alert_ref = message.to_reference(fail_if_not_exists=False) alert_ref = message.to_reference(fail_if_not_exists=False) if message else None
if only_warn: if only_warn:
# TODO: Warn ticket # TODO: Warn ticket
warning = discord.Embed( warning = discord.Embed(

View File

@@ -237,7 +237,7 @@ class ChannelSetting(Generic[ParentID, CT], InteractiveSetting[ParentID, int, CT
_selector_placeholder = "Select a Channel" _selector_placeholder = "Select a Channel"
channel_types: list[discord.ChannelType] = [] channel_types: list[discord.ChannelType] = []
_allow_object = True _allow_object = False
@classmethod @classmethod
def _data_from_value(cls, parent_id, value, **kwargs): def _data_from_value(cls, parent_id, value, **kwargs):
@@ -368,7 +368,7 @@ class RoleSetting(InteractiveSetting[ParentID, int, Union[discord.Role, discord.
_accepts = _p('settype:role|accepts', "A role name or id") _accepts = _p('settype:role|accepts', "A role name or id")
_selector_placeholder = "Select a Role" _selector_placeholder = "Select a Role"
_allow_object = True _allow_object = False
@classmethod @classmethod
def _get_guildid(cls, parent_id: int, **kwargs) -> int: def _get_guildid(cls, parent_id: int, **kwargs) -> int:
@@ -915,7 +915,7 @@ class TimezoneSetting(InteractiveSetting[ParentID, str, TZT]):
name=t(_p( name=t(_p(
'set_type:timezone|acmpl|no_matching', 'set_type:timezone|acmpl|no_matching',
"No timezones matching '{input}'!" "No timezones matching '{input}'!"
)).format(input=partial), )).format(input=partial)[:100],
value=partial value=partial
) )
] ]
@@ -930,7 +930,7 @@ class TimezoneSetting(InteractiveSetting[ParentID, str, TZT]):
"{tz} (Currently {now})" "{tz} (Currently {now})"
)).format(tz=tz, now=nowstr) )).format(tz=tz, now=nowstr)
choice = appcmds.Choice( choice = appcmds.Choice(
name=name, name=name[:100],
value=tz value=tz
) )
choices.append(choice) choices.append(choice)

View File

@@ -236,7 +236,7 @@ class InteractiveSetting(BaseSetting[ParentID, SettingData, SettingValue]):
Callable[[ParentID, SettingData], Coroutine[Any, Any, None]] Callable[[ParentID, SettingData], Coroutine[Any, Any, None]]
""" """
if self._event is not None and (bot := ctx_bot.get()) is not None: if self._event is not None and (bot := ctx_bot.get()) is not None:
bot.dispatch(self._event, self.parent_id, self.data) bot.dispatch(self._event, self.parent_id, self)
def get_listener(self, key): def get_listener(self, key):
return self._listeners_.get(key, None) return self._listeners_.get(key, None)

View File

@@ -13,6 +13,7 @@ from meta.errors import UserInputError
from meta.logger import log_wrap, logging_context from meta.logger import log_wrap, logging_context
from meta.sharding import THIS_SHARD from meta.sharding import THIS_SHARD
from meta.app import appname from meta.app import appname
from meta.monitor import ComponentMonitor, ComponentStatus, StatusLevel
from utils.lib import utc_now, error_embed from utils.lib import utc_now, error_embed
from wards import low_management_ward, sys_admin_ward from wards import low_management_ward, sys_admin_ward
@@ -42,10 +43,14 @@ class TextTrackerCog(LionCog):
self.data = bot.db.load_registry(TextTrackerData()) self.data = bot.db.load_registry(TextTrackerData())
self.settings = TextTrackerSettings() self.settings = TextTrackerSettings()
self.global_settings = TextTrackerGlobalSettings() self.global_settings = TextTrackerGlobalSettings()
self.monitor = ComponentMonitor('TextTracker', self._monitor)
self.babel = babel self.babel = babel
self.sessionq = asyncio.Queue(maxsize=0) self.sessionq = asyncio.Queue(maxsize=0)
self.ready = asyncio.Event()
self.errors = 0
# Map of ongoing text sessions # Map of ongoing text sessions
# guildid -> (userid -> TextSession) # guildid -> (userid -> TextSession)
self.ongoing = defaultdict(dict) self.ongoing = defaultdict(dict)
@@ -54,7 +59,41 @@ class TextTrackerCog(LionCog):
self.untracked_channels = self.settings.UntrackedTextChannels._cache self.untracked_channels = self.settings.UntrackedTextChannels._cache
async def _monitor(self):
state = (
"<"
"TextTracker"
" ready={ready}"
" queued={queued}"
" errors={errors}"
" running={running}"
" consumer={consumer}"
">"
)
data = dict(
ready=self.ready.is_set(),
queued=self.sessionq.qsize(),
errors=self.errors,
running=sum(len(usessions) for usessions in self.ongoing.values()),
consumer="'Running'" if (self._consumer_task and not self._consumer_task.done()) else "'Not Running'",
)
if not self.ready.is_set():
level = StatusLevel.STARTING
info = f"(STARTING) Not initialised. {state}"
elif not self._consumer_task:
level = StatusLevel.ERRORED
info = f"(ERROR) Consumer task not running. {state}"
elif self.errors > 1:
level = StatusLevel.UNSURE
info = f"(UNSURE) Errors occurred while consuming. {state}"
else:
level = StatusLevel.OKAY
info = f"(OK) Message tracking operational. {state}"
return ComponentStatus(level, info, info, data)
async def cog_load(self): async def cog_load(self):
self.bot.system_monitor.add_component(self.monitor)
await self.data.init() await self.data.init()
self.bot.core.guild_config.register_model_setting(self.settings.XPPerPeriod) self.bot.core.guild_config.register_model_setting(self.settings.XPPerPeriod)
@@ -83,6 +122,7 @@ class TextTrackerCog(LionCog):
await self.initialise() await self.initialise()
async def cog_unload(self): async def cog_unload(self):
self.ready.clear()
if self._consumer_task is not None: if self._consumer_task is not None:
self._consumer_task.cancel() self._consumer_task.cancel()
@@ -104,7 +144,7 @@ class TextTrackerCog(LionCog):
await self.bot.core.lions.fetch_member(session.guildid, session.userid) await self.bot.core.lions.fetch_member(session.guildid, session.userid)
self.sessionq.put_nowait(session) self.sessionq.put_nowait(session)
@log_wrap(stack=['Text Sessions', 'Message Event']) @log_wrap(stack=['Text Sessions', 'Consumer'])
async def _session_consumer(self): async def _session_consumer(self):
""" """
Process completed sessions in batches of length `batchsize`. Process completed sessions in batches of length `batchsize`.
@@ -132,6 +172,7 @@ class TextTrackerCog(LionCog):
logger.exception( logger.exception(
"Unknown exception processing batch of text sessions! Discarding and continuing." "Unknown exception processing batch of text sessions! Discarding and continuing."
) )
self.errors += 1
batch = [] batch = []
counter = 0 counter = 0
last_time = time.monotonic() last_time = time.monotonic()
@@ -157,6 +198,9 @@ class TextTrackerCog(LionCog):
# Batch-fetch lguilds # Batch-fetch lguilds
lguilds = await self.bot.core.lions.fetch_guilds(*{session.guildid for session in batch}) lguilds = await self.bot.core.lions.fetch_guilds(*{session.guildid for session in batch})
await self.bot.core.lions.fetch_members(
*((session.guildid, session.userid) for session in batch)
)
# Build data # Build data
rows = [] rows = []
@@ -202,9 +246,11 @@ class TextTrackerCog(LionCog):
""" """
Launch the session consumer. Launch the session consumer.
""" """
self.ready.clear()
if self._consumer_task and not self._consumer_task.cancelled(): if self._consumer_task and not self._consumer_task.cancelled():
self._consumer_task.cancel() self._consumer_task.cancel()
self._consumer_task = asyncio.create_task(self._session_consumer()) self._consumer_task = asyncio.create_task(self._session_consumer(), name='text-session-consumer')
self.ready.set()
logger.info("Launched text session consumer.") logger.info("Launched text session consumer.")
@LionCog.listener('on_message') @LionCog.listener('on_message')

View File

@@ -301,7 +301,7 @@ class TextTrackerData(Registry):
FROM text_sessions FROM text_sessions
WHERE guildid = %s AND start_time >= %s WHERE guildid = %s AND start_time >= %s
GROUP BY userid GROUP BY userid
ORDER BY ORDER BY user_total DESC
""" """
) )
async with cls._connector.connection() as conn: async with cls._connector.connection() as conn:
@@ -325,7 +325,7 @@ class TextTrackerData(Registry):
FROM text_sessions FROM text_sessions
WHERE guildid = %s WHERE guildid = %s
GROUP BY userid GROUP BY userid
ORDER BY ORDER BY user_total DESC
""" """
) )
async with cls._connector.connection() as conn: async with cls._connector.connection() as conn:

View File

@@ -1,17 +1,17 @@
from typing import Optional from typing import Optional
import asyncio import asyncio
import datetime as dt import datetime as dt
from collections import defaultdict
import discord import discord
from discord.ext import commands as cmds from discord.ext import commands as cmds
from discord import app_commands as appcmds from discord import app_commands as appcmds
from data import Condition
from meta import LionBot, LionCog, LionContext from meta import LionBot, LionCog, LionContext
from meta.errors import UserInputError from meta.logger import log_wrap
from meta.logger import log_wrap, logging_context
from meta.sharding import THIS_SHARD from meta.sharding import THIS_SHARD
from utils.lib import utc_now, error_embed from meta.monitor import ComponentMonitor, ComponentStatus, StatusLevel
from utils.lib import utc_now
from core.lion_guild import VoiceMode from core.lion_guild import VoiceMode
from wards import low_management_ward, moderator_ctxward from wards import low_management_ward, moderator_ctxward
@@ -35,6 +35,7 @@ class VoiceTrackerCog(LionCog):
self.data = bot.db.load_registry(VoiceTrackerData()) self.data = bot.db.load_registry(VoiceTrackerData())
self.settings = VoiceTrackerSettings() self.settings = VoiceTrackerSettings()
self.babel = babel self.babel = babel
self.monitor = ComponentMonitor('VoiceTracker', self._monitor)
# State # State
# Flag indicating whether local voice sessions have been initialised # Flag indicating whether local voice sessions have been initialised
@@ -44,7 +45,77 @@ class VoiceTrackerCog(LionCog):
self.untracked_channels = self.settings.UntrackedChannels._cache self.untracked_channels = self.settings.UntrackedChannels._cache
self.active_sessions = VoiceSession._active_sessions_
async def _monitor(self):
state = (
"<"
"VoiceTracker"
" initialised={initialised}"
" active={active}"
" pending={pending}"
" ongoing={ongoing}"
" locked={locked}"
" actual={actual}"
" channels={channels}"
" cached={cached}"
" initial_event={initial_event}"
" lock={lock}"
">"
)
data = dict(
initialised=self.initialised.is_set(),
active=0,
pending=0,
ongoing=0,
locked=0,
actual=0,
channels=0,
cached=len(VoiceSession._sessions_),
initial_event=self.initialised,
lock=self.tracking_lock
)
channels = set()
for tguild in self.active_sessions.values():
for session in tguild.values():
data['active'] += 1
if session.activity is SessionState.ONGOING:
data['ongoing'] += 1
elif session.activity is SessionState.PENDING:
data['pending'] += 1
if session.lock.locked():
data['locked'] += 1
if session.state:
channels.add(session.state.channelid)
data['channels'] = len(channels)
for guild in self.bot.guilds:
for channel in guild.voice_channels:
if not self.is_untracked(channel):
for member in channel.members:
if member.voice and not member.bot:
data['actual'] += 1
if not self.initialised.is_set():
level = StatusLevel.STARTING
info = f"(STARTING) Not initialised. {state}"
elif self.tracking_lock.locked():
level = StatusLevel.WAITING
info = f"(WAITING) Waiting for tracking lock. {state}"
elif data['actual'] != data['active']:
level = StatusLevel.UNSURE
info = f"(UNSURE) Actual sessions do not match active. {state}"
else:
level = StatusLevel.OKAY
info = f"(OK) Voice tracking operational. {state}"
return ComponentStatus(level, info, info, data)
async def cog_load(self): async def cog_load(self):
self.bot.system_monitor.add_component(self.monitor)
await self.data.init() await self.data.init()
self.bot.core.guild_config.register_model_setting(self.settings.HourlyReward) self.bot.core.guild_config.register_model_setting(self.settings.HourlyReward)
@@ -71,7 +142,8 @@ class VoiceTrackerCog(LionCog):
# Simultaneously! # Simultaneously!
... ...
def get_session(self, guildid, userid, **kwargs) -> VoiceSession: # ----- Cog API -----
def get_session(self, guildid, userid, **kwargs) -> Optional[VoiceSession]:
""" """
Get the VoiceSession for the given member. Get the VoiceSession for the given member.
@@ -91,6 +163,197 @@ class VoiceTrackerCog(LionCog):
untracked = False untracked = False
return untracked return untracked
@log_wrap(action='load sessions')
async def _load_sessions(self,
states: dict[tuple[int, int], TrackedVoiceState],
ongoing: list[VoiceTrackerData.VoiceSessionsOngoing]):
"""
Load voice sessions from provided states and ongoing data.
Provided data may cross multiple guilds.
Assumes all states which do not have data should be started.
Assumes all ongoing data which does not have states should be ended.
Assumes untracked channel data is up to date.
"""
OngoingData = VoiceTrackerData.VoiceSessionsOngoing
# Compute time to end complete sessions
now = utc_now()
last_update = max((row.last_update for row in ongoing), default=now)
end_at = min(last_update + dt.timedelta(seconds=3600), now)
# Bulk fetches for voice-active members and guilds
active_memberids = list(states.keys())
active_guildids = set(gid for gid, _ in states)
if states:
lguilds = await self.bot.core.lions.fetch_guilds(*active_guildids)
await self.bot.core.lions.fetch_members(*active_memberids)
tracked_today_data = await self.data.VoiceSessions.multiple_voice_tracked_since(
*((guildid, userid, lguilds[guildid].today) for guildid, userid in active_memberids)
)
tracked_today = {(row['guildid'], row['userid']): row['tracked'] for row in tracked_today_data}
else:
lguilds = {}
tracked_today = {}
# Zip session information together by memberid keys
sessions: dict[tuple[int, int], tuple[Optional[TrackedVoiceState], Optional[OngoingData]]] = {}
for row in ongoing:
key = (row.guildid, row.userid)
sessions[key] = (states.pop(key, None), row)
for key, state in states.items():
sessions[key] = (state, None)
# Now split up session information to fill action maps
close_ongoing = []
update_ongoing = []
create_ongoing = []
expiries = {}
load_sessions = []
schedule_sessions = {}
for (gid, uid), (state, data) in sessions.items():
if state is not None:
# Member is active
if data is not None and data.channelid != state.channelid:
# Ongoing session does not match active state
# Close the session, but still create/schedule the state
close_ongoing.append((gid, uid, end_at))
data = None
# Now create/update/schedule active session
# Also create/update data if required
lguild = lguilds[gid]
tomorrow = lguild.today + dt.timedelta(days=1)
cap = lguild.config.get('daily_voice_cap').value
tracked = tracked_today[gid, uid]
hourly_rate = await self._calculate_rate(gid, uid, state)
if tracked >= cap:
# Active session is already over cap
# Stop ongoing if it exists, and schedule next session start
delay = (tomorrow - now).total_seconds()
start_time = tomorrow
expiry = tomorrow + dt.timedelta(seconds=cap)
schedule_sessions[(gid, uid)] = (delay, start_time, expiry, state, hourly_rate)
if data is not None:
close_ongoing.append((
gid, uid,
max(now - dt.timedelta(seconds=tracked - cap), data.last_update)
))
else:
# Active session, update/create data
expiry = now + dt.timedelta(seconds=(cap - tracked))
if expiry > tomorrow:
expiry = tomorrow + dt.timedelta(seconds=cap)
expiries[(gid, uid)] = expiry
if data is not None:
update_ongoing.append((gid, uid, now, state.stream, state.video, hourly_rate))
else:
create_ongoing.append((
gid, uid, state.channelid, now, now, state.stream, state.video, hourly_rate
))
elif data is not None:
# Ongoing data has no state, close the session
close_ongoing.append((gid, uid, end_at))
# Close data that needs closing
if close_ongoing:
logger.info(
f"Ending {len(close_ongoing)} ongoing voice sessions with no matching voice state."
)
await self.data.VoiceSessionsOngoing.close_voice_sessions_at(*close_ongoing)
# Update data that needs updating
if update_ongoing:
logger.info(
f"Continuing {len(update_ongoing)} ongoing voice sessions with matching voice state."
)
rows = await self.data.VoiceSessionsOngoing.update_voice_sessions_at(*update_ongoing)
load_sessions.extend(rows)
# Create data that needs creating
if create_ongoing:
logger.info(
f"Creating {len(create_ongoing)} voice sessions from new voice states."
)
# First ensure the tracked channels exist
cids = set((item[2], item[0]) for item in create_ongoing)
await self.data.TrackedChannel.fetch_multiple(*cids)
# Then create the sessions
rows = await self.data.VoiceSessionsOngoing.table.insert_many(
('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream',
'live_video', 'hourly_coins'),
*create_ongoing
).with_adapter(self.data.VoiceSessionsOngoing._make_rows)
load_sessions.extend(rows)
# Create sessions from ongoing, with expiry
for row in load_sessions:
VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)])
# Schedule starting sessions
for (gid, uid), args in schedule_sessions.items():
session = VoiceSession.get(self.bot, gid, uid)
await session.schedule_start(*args)
logger.info(
f"Successfully loaded {len(load_sessions)} and scheduled {len(schedule_sessions)} voice sessions."
)
@log_wrap(action='refresh guild sessions')
async def refresh_guild_sessions(self, guild: discord.Guild):
"""
Idempotently refresh all guild voice sessions in the given guild.
Essentially a lighter version of `initialise`.
"""
# TODO: There is a very small potential window for a race condition here
# Since we do not have a version of 'handle_events' for the guild
# We may actually handle events before starting refresh
# Causing sessions to have invalid state.
# If this becomes an actual problem, implement an `ignore_guilds` set flag of some form...
logger.debug(f"Beginning voice state refresh for <gid: {guild.id}>")
async with self.tracking_lock:
# TODO: Add a 'lock holder' attribute which is readable by the monitor
logger.debug(f"Voice state refresh for <gid: {guild.id}> is past lock")
# Deactivate any ongoing session tasks in this guild
active = self.active_sessions.pop(guild.id, {}).values()
for session in active:
session.cancel()
# Update untracked channel information for this guild
self.untracked_channels.pop(guild.id, None)
await self.settings.UntrackedChannels.get(guild.id)
# Read tracked voice states
states = {}
for channel in guild.voice_channels:
if not self.is_untracked(channel):
for member in channel.members:
if member.voice and not member.bot:
state = TrackedVoiceState.from_voice_state(member.voice)
states[(guild.id, member.id)] = state
logger.debug(f"Loaded {len(states)} tracked voice states for <gid: {guild.id}>.")
# Read ongoing session data
ongoing = await self.data.VoiceSessionsOngoing.fetch_where(guildid=guild.id)
logger.debug(
f"Loaded {len(ongoing)} ongoing voice sessions from data for <gid: {guild.id}>. Beginning reload."
)
await self._load_sessions(states, ongoing)
logger.info(
f"Completed guild voice session reload for <gid: {guild.id}> "
f"with '{len(self.active_sessions[guild.id])}' active sessions."
)
# ----- Event Handlers -----
@LionCog.listener('on_ready') @LionCog.listener('on_ready')
@log_wrap(action='Init Voice Sessions') @log_wrap(action='Init Voice Sessions')
async def initialise(self): async def initialise(self):
@@ -99,192 +362,54 @@ class VoiceTrackerCog(LionCog):
Ends ongoing sessions for members who are not in the given voice channel. Ends ongoing sessions for members who are not in the given voice channel.
""" """
# First take the tracking lock logger.info("Beginning voice session state initialisation. Disabling voice event handling.")
# Ensures current event handling completes before re-initialisation # If `on_ready` is called, that means we are initialising
# or we missed events and need to re-initialise.
# Start ignoring events because they may be working on stale or partial state
self.handle_events = False
# Services which read our cache should wait for initialisation before taking the lock
self.initialised.clear()
# Wait for running events to complete
# And make sure future events will be processed after initialisation
# Note only events occurring after our voice state snapshot will be processed
async with self.tracking_lock: async with self.tracking_lock:
logger.info("Reloading ongoing voice sessions") # Deactivate all ongoing sessions
active = [session for gsessions in self.active_sessions.values() for session in gsessions.values()]
for session in active:
session.cancel()
self.active_sessions.clear()
# Also clear the session registry cache
VoiceSession._sessions_.clear()
# Refresh untracked information for all guilds we are in
await self.settings.UntrackedChannels.setup(self.bot)
logger.debug("Disabling voice state event handling.")
self.handle_events = False
self.initialised.clear()
# Read and save the tracked voice states of all visible voice channels # Read and save the tracked voice states of all visible voice channels
voice_members = {} # (guildid, userid) -> TrackedVoiceState states = {}
voice_guilds = set()
for guild in self.bot.guilds: for guild in self.bot.guilds:
untracked = self.untracked_channels.get(guild.id, ())
for channel in guild.voice_channels: for channel in guild.voice_channels:
if channel.id in untracked: if not self.is_untracked(channel):
continue for member in channel.members:
if channel.category_id and channel.category_id in untracked: if member.voice and not member.bot:
continue state = TrackedVoiceState.from_voice_state(member.voice)
states[(guild.id, member.id)] = state
for member in channel.members: logger.info(
if member.bot: f"Saved voice snapshot with {len(states)} tracked states. Re-enabling voice event handling."
continue )
voice_members[(guild.id, member.id)] = TrackedVoiceState.from_voice_state(member.voice)
voice_guilds.add(guild.id)
logger.debug(f"Cached {len(voice_members)} members from voice channels.")
self.handle_events = True self.handle_events = True
logger.debug("Re-enabled voice state event handling.")
# Iterate through members with current ongoing sessions # Load ongoing session data for the entire shard
# End or update sessions as needed, based on saved tracked state ongoing = await self.data.VoiceSessionsOngoing.fetch_where(THIS_SHARD)
ongoing_rows = await self.data.VoiceSessionsOngoing.fetch_where( logger.info(
guildid=[guild.id for guild in self.bot.guilds] f"Retrieved {len(ongoing)} ongoing voice sessions from data. Beginning reload."
) )
logger.debug(
f"Loaded {len(ongoing_rows)} ongoing sessions from data. Splitting into complete and incomplete."
)
complete = []
incomplete = []
incomplete_guildids = set()
# Compute time to end complete sessions await self._load_sessions(states, ongoing)
now = utc_now()
last_update = max((row.last_update for row in ongoing_rows), default=now)
end_at = min(last_update + dt.timedelta(seconds=3600), now)
for row in ongoing_rows:
key = (row.guildid, row.userid)
state = voice_members.get(key, None)
untracked = self.untracked_channels.get(row.guildid, [])
if (
state
and state.channelid == row.channelid
and state.channelid not in untracked
and (ch := self.bot.get_channel(state.channelid)) is not None
and (not ch.category_id or ch.category_id not in untracked)
):
# Mark session as ongoing
incomplete.append((row, state))
incomplete_guildids.add(row.guildid)
voice_members.pop(key)
else:
# Mark session as complete
complete.append((row.guildid, row.userid, end_at))
# Load required guild data into cache
active_guildids = incomplete_guildids.union(voice_guilds)
if active_guildids:
await self.bot.core.data.Guild.fetch_where(guildid=tuple(active_guildids))
lguilds = {guildid: await self.bot.core.lions.fetch_guild(guildid) for guildid in active_guildids}
# Calculate tracked_today for members with ongoing sessions
active_members = set((row.guildid, row.userid) for row, _ in incomplete)
active_members.update(voice_members.keys())
if active_members:
tracked_today_data = await self.data.VoiceSessions.multiple_voice_tracked_since(
*((guildid, userid, lguilds[guildid].today) for guildid, userid in active_members)
)
else:
tracked_today_data = []
tracked_today = {(row['guildid'], row['userid']): row['tracked'] for row in tracked_today_data}
if incomplete:
# Note that study_time_since _includes_ ongoing sessions in its calculation
# So expiry times are "time left today until cap" or "tomorrow + cap"
to_load = [] # (session_data, expiry_time)
to_update = [] # (guildid, userid, update_at, stream, video, hourly_rate)
for session_data, state in incomplete:
# Calculate expiry times
lguild = lguilds[session_data.guildid]
cap = lguild.config.get('daily_voice_cap').value
tracked = tracked_today[(session_data.guildid, session_data.userid)]
if tracked >= cap:
# Already over cap
complete.append((
session_data.guildid,
session_data.userid,
max(now + dt.timedelta(seconds=tracked - cap), session_data.last_update)
))
else:
tomorrow = lguild.today + dt.timedelta(days=1)
expiry = now + dt.timedelta(seconds=(cap - tracked))
if expiry > tomorrow:
expiry = tomorrow + dt.timedelta(seconds=cap)
to_load.append((session_data, expiry))
# TODO: Probably better to do this by batch
# Could force all bonus calculators to accept list of members
hourly_rate = await self._calculate_rate(session_data.guildid, session_data.userid, state)
to_update.append((
session_data.guildid,
session_data.userid,
now,
state.stream,
state.video,
hourly_rate
))
# Run the updates, note that session_data uses registry pattern so will also update
if to_update:
await self.data.VoiceSessionsOngoing.update_voice_sessions_at(*to_update)
# Load the sessions
for data, expiry in to_load:
VoiceSession.from_ongoing(self.bot, data, expiry)
logger.info(f"Resumed {len(to_load)} ongoing voice sessions.")
if complete:
logger.info(f"Ending {len(complete)} out-of-date or expired study sessions.")
# Complete sessions just need a mass end_voice_session_at()
await self.data.VoiceSessionsOngoing.close_voice_sessions_at(*complete)
# Then iterate through the saved states from tracked voice channels
# Start sessions if they don't already exist
if voice_members:
expiries = {} # (guildid, memberid) -> expiry time
to_create = [] # (guildid, userid, channelid, start_time, last_update, live_stream, live_video, rate)
for (guildid, userid), state in voice_members.items():
untracked = self.untracked_channels.get(guildid, [])
channel = self.bot.get_channel(state.channelid)
if (
channel
and channel.id not in untracked
and (not channel.category_id or channel.category_id not in untracked)
):
# State is from member in tracked voice channel
# Calculate expiry
lguild = lguilds[guildid]
cap = lguild.config.get('daily_voice_cap').value
tracked = tracked_today[(guildid, userid)]
if tracked < cap:
tomorrow = lguild.today + dt.timedelta(days=1)
expiry = now + dt.timedelta(seconds=(cap - tracked))
if expiry > tomorrow:
expiry = tomorrow + dt.timedelta(seconds=cap)
expiries[(guildid, userid)] = expiry
hourly_rate = await self._calculate_rate(guildid, userid, state)
to_create.append((
guildid, userid,
state.channelid,
now, now,
state.stream, state.video,
hourly_rate
))
# Bulk create the ongoing sessions
if to_create:
# First ensure the lion members exist
await self.bot.core.lions.fetch_members(
*(item[:2] for item in to_create)
)
# Then ensure the TrackedChannels exist
cids = set((item[2], item[0]) for item in to_create)
await self.data.TrackedChannel.fetch_multiple(*cids)
# Then actually create the ongoing sessions
rows = await self.data.VoiceSessionsOngoing.table.insert_many(
('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream',
'live_video', 'hourly_coins'),
*to_create
).with_adapter(self.data.VoiceSessionsOngoing._make_rows)
for row in rows:
VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)])
logger.info(f"Started {len(rows)} new voice sessions from voice channels!")
self.initialised.set() self.initialised.set()
@LionCog.listener("on_voice_state_update") @LionCog.listener("on_voice_state_update")
@@ -314,6 +439,9 @@ class VoiceTrackerCog(LionCog):
# If tracked state did not change, ignore event # If tracked state did not change, ignore event
return return
bchannel = before.channel if before else None
achannel = after.channel if after else None
# Take tracking lock # Take tracking lock
async with self.tracking_lock: async with self.tracking_lock:
# Fetch tracked member session state # Fetch tracked member session state
@@ -334,7 +462,7 @@ class VoiceTrackerCog(LionCog):
"Voice event does not match session information! " "Voice event does not match session information! "
f"Member '{member.name}' <uid:{member.id}> " f"Member '{member.name}' <uid:{member.id}> "
f"of guild '{member.guild.name}' <gid:{member.guild.id}> " f"of guild '{member.guild.name}' <gid:{member.guild.id}> "
f"left channel '#{before.channel.name}' <cid:{leaving}> " f"left channel '{bchannel}' <cid:{leaving}> "
f"during voice session in channel <cid:{tstate.channelid}>!" f"during voice session in channel <cid:{tstate.channelid}>!"
) )
# Close (or cancel) active session # Close (or cancel) active session
@@ -344,16 +472,13 @@ class VoiceTrackerCog(LionCog):
" because they left the channel." " because they left the channel."
) )
await session.close() await session.close()
elif ( elif not self.is_untracked(bchannel):
leaving not in untracked and
not (before.channel.category_id and before.channel.category_id in untracked)
):
# Leaving tracked channel without an active session? # Leaving tracked channel without an active session?
logger.warning( logger.warning(
"Voice event does not match session information! " "Voice event does not match session information! "
f"Member '{member.name}' <uid:{member.id}> " f"Member '{member.name}' <uid:{member.id}> "
f"of guild '{member.guild.name}' <gid:{member.guild.id}> " f"of guild '{member.guild.name}' <gid:{member.guild.id}> "
f"left tracked channel '#{before.channel.name}' <cid:{leaving}> " f"left tracked channel '{bchannel}' <cid:{leaving}> "
f"with no matching voice session!" f"with no matching voice session!"
) )
@@ -365,14 +490,11 @@ class VoiceTrackerCog(LionCog):
"Voice event does not match session information! " "Voice event does not match session information! "
f"Member '{member.name}' <uid:{member.id}> " f"Member '{member.name}' <uid:{member.id}> "
f"of guild '{member.guild.name}' <gid:{member.guild.id}> " f"of guild '{member.guild.name}' <gid:{member.guild.id}> "
f"joined channel '#{after.channel.name}' <cid:{joining}> " f"joined channel '{achannel}' <cid:{joining}> "
f"during voice session in channel <cid:{tstate.channelid}>!" f"during voice session in channel <cid:{tstate.channelid}>!"
) )
await session.close() await session.close()
if ( if not self.is_untracked(achannel):
joining not in untracked and
not (after.channel.category_id and after.channel.category_id in untracked)
):
# If the channel they are joining is tracked, schedule a session start for them # If the channel they are joining is tracked, schedule a session start for them
delay, start, expiry = await self._session_boundaries_for(member.guild.id, member.id) delay, start, expiry = await self._session_boundaries_for(member.guild.id, member.id)
hourly_rate = await self._calculate_rate(member.guild.id, member.id, astate) hourly_rate = await self._calculate_rate(member.guild.id, member.id, astate)
@@ -380,7 +502,7 @@ class VoiceTrackerCog(LionCog):
logger.debug( logger.debug(
f"Scheduling voice session for member `{member.name}' <uid:{member.id}> " f"Scheduling voice session for member `{member.name}' <uid:{member.id}> "
f"in guild '{member.guild.name}' <gid: member.guild.id> " f"in guild '{member.guild.name}' <gid: member.guild.id> "
f"in channel '{after.channel.name}' <cid: {after.channel.id}>. " f"in channel '{achannel}' <cid: {after.channel.id}>. "
f"Session will start at {start}, expire at {expiry}, and confirm in {delay}." f"Session will start at {start}, expire at {expiry}, and confirm in {delay}."
) )
await session.schedule_start(delay, start, expiry, astate, hourly_rate) await session.schedule_start(delay, start, expiry, astate, hourly_rate)
@@ -391,116 +513,24 @@ class VoiceTrackerCog(LionCog):
hourly_rate = await self._calculate_rate(member.guild.id, member.id, astate) hourly_rate = await self._calculate_rate(member.guild.id, member.id, astate)
await session.update(new_state=astate, new_rate=hourly_rate) await session.update(new_state=astate, new_rate=hourly_rate)
@LionCog.listener("on_guild_setting_update_untracked_channels") @LionCog.listener("on_guildset_untracked_channels")
async def update_untracked_channels(self, guildid, setting): @LionCog.listener("on_guildset_hourly_reward")
""" @LionCog.listener("on_guildset_hourly_live_bonus")
Close sessions in untracked channels, and recalculate previously untracked sessions @LionCog.listener("on_guildset_daily_voice_cap")
""" @LionCog.listener("on_guildset_timezone")
async def _event_refresh_guild(self, guildid: int, setting):
if not self.handle_events: if not self.handle_events:
return return
guild = self.bot.get_guild(guildid)
async with self.tracking_lock: if guild is None:
lguild = await self.bot.core.lions.fetch_guild(guildid) logger.warning(
guild = self.bot.get_guild(guildid) f"Voice tracker discarding '{setting.setting_id}' event for unknown guild <gid: {guildid}>."
if not guild: )
# Left guild while waiting on lock else:
return logger.debug(
cap = lguild.config.get('daily_voice_cap').value f"Voice tracker handling '{setting.setting_id}' event for guild <gid: {guildid}>."
untracked = self.untracked_channels.get(guildid, []) )
now = utc_now() await self.refresh_guild_sessions(guild)
# Iterate through active sessions, close any that are in untracked channels
active = VoiceSession._active_sessions_.get(guildid, {})
for session in list(active.values()):
if session.state.channelid in untracked:
await session.close()
# Iterate through voice members, open new sessions if needed
expiries = {}
to_create = []
for channel in guild.voice_channels:
if channel.id in untracked:
continue
for member in channel.members:
if self.get_session(guildid, member.id).activity:
# Already have an active session for this member
continue
userid = member.id
state = TrackedVoiceState.from_voice_state(member.voice)
# TODO: Take into account tracked_today time?
# TODO: Make a per-guild refresh function to stay DRY
tomorrow = lguild.today + dt.timedelta(days=1)
expiry = now + dt.timedelta(seconds=cap)
if expiry > tomorrow:
expiry = tomorrow + dt.timedelta(seconds=cap)
expiries[(guildid, userid)] = expiry
hourly_rate = await self._calculate_rate(guildid, userid, state)
to_create.append((
guildid, userid,
state.channelid,
now, now,
state.stream, state.video,
hourly_rate
))
if to_create:
# Ensure LionMembers exist
await self.bot.core.lions.fetch_members(
*(item[:2] for item in to_create)
)
# Ensure TrackedChannels exist
cids = set((item[2], item[0]) for item in to_create)
await self.data.TrackedChannel.fetch_multiple(*cids)
# Create new sessions
rows = await self.data.VoiceSessionsOngoing.table.insert_many(
('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream',
'live_video', 'hourly_coins'),
*to_create
).with_adapter(self.data.VoiceSessionsOngoing._make_rows)
for row in rows:
VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)])
logger.info(
f"Started {len(rows)} new voice sessions from voice members "
f"in previously untracked channels of guild '{guild.name}' <gid:{guildid}>."
)
@LionCog.listener("on_guild_setting_update_hourly_reward")
async def update_hourly_reward(self, guildid, setting):
if not self.handle_events:
return
async with self.tracking_lock:
sessions = VoiceSession._active_sessions_.get(guildid, {})
for session in list(sessions.values()):
hourly_rate = await self._calculate_rate(session.guildid, session.userid, session.state)
await session.update(new_rate=hourly_rate)
@LionCog.listener("on_guild_setting_update_hourly_live_bonus")
async def update_hourly_live_bonus(self, guildid, setting):
if not self.handle_events:
return
async with self.tracking_lock:
sessions = VoiceSession._active_sessions_.get(guildid, {})
for session in list(sessions.values()):
hourly_rate = await self._calculate_rate(session.guildid, session.userid, session.state)
await session.update(new_rate=hourly_rate)
@LionCog.listener("on_guild_setting_update_daily_voice_cap")
async def update_daily_voice_cap(self, guildid, setting):
# TODO: Guild daily_voice_cap setting triggers session expiry recalculation for all sessions
...
@LionCog.listener("on_guild_setting_update_timezone")
@log_wrap(action='Voice Track')
@log_wrap(action='Timezone Update')
async def update_timezone(self, guildid, setting):
# TODO: Guild timezone setting triggers studied_today cache rebuild
logger.info("Received dispatch event for timezone change!")
async def _calculate_rate(self, guildid, userid, state): async def _calculate_rate(self, guildid, userid, state):
""" """
@@ -522,7 +552,7 @@ class VoiceTrackerCog(LionCog):
return hourly_rate return hourly_rate
async def _session_boundaries_for(self, guildid: int, userid: int) -> tuple[int, dt.datetime, dt.datetime]: async def _session_boundaries_for(self, guildid: int, userid: int) -> tuple[float, dt.datetime, dt.datetime]:
""" """
Compute when the next session for this member should start and expire. Compute when the next session for this member should start and expire.
@@ -539,7 +569,7 @@ class VoiceTrackerCog(LionCog):
""" """
lguild = await self.bot.core.lions.fetch_guild(guildid) lguild = await self.bot.core.lions.fetch_guild(guildid)
now = lguild.now now = lguild.now
tomorrow = now + dt.timedelta(days=1) tomorrow = lguild.today + dt.timedelta(days=1)
studied_today = await self.fetch_tracked_today(guildid, userid) studied_today = await self.fetch_tracked_today(guildid, userid)
cap = lguild.config.get('daily_voice_cap').value cap = lguild.config.get('daily_voice_cap').value
@@ -552,7 +582,7 @@ class VoiceTrackerCog(LionCog):
delay = 20 delay = 20
expiry = start_time + dt.timedelta(seconds=cap) expiry = start_time + dt.timedelta(seconds=cap)
if expiry >= tomorrow: if expiry > tomorrow:
expiry = tomorrow + dt.timedelta(seconds=cap) expiry = tomorrow + dt.timedelta(seconds=cap)
return (delay, start_time, expiry) return (delay, start_time, expiry)
@@ -574,61 +604,9 @@ class VoiceTrackerCog(LionCog):
Initialise and start required new sessions from voice channel members when we join a guild. Initialise and start required new sessions from voice channel members when we join a guild.
""" """
if not self.handle_events: if not self.handle_events:
# Initialisation will take care of it for us
return return
await self.refresh_guild_sessions(guild)
async with self.tracking_lock:
guildid = guild.id
lguild = await self.bot.core.lions.fetch_guild(guildid)
cap = lguild.config.get('daily_voice_cap').value
untracked = self.untracked_channels.get(guildid, [])
now = utc_now()
expiries = {}
to_create = []
for channel in guild.voice_channels:
if channel.id in untracked:
continue
for member in channel.members:
userid = member.id
state = TrackedVoiceState.from_voice_state(member.voice)
tomorrow = lguild.today + dt.timedelta(days=1)
expiry = now + dt.timedelta(seconds=cap)
if expiry > tomorrow:
expiry = tomorrow + dt.timedelta(seconds=cap)
expiries[(guildid, userid)] = expiry
hourly_rate = await self._calculate_rate(guildid, userid, state)
to_create.append((
guildid, userid,
state.channelid,
now, now,
state.stream, state.video,
hourly_rate
))
if to_create:
# Ensure LionMembers exist
await self.bot.core.lions.fetch_members(
*(item[:2] for item in to_create)
)
# Ensure TrackedChannels exist
cids = set((item[2], item[0]) for item in to_create)
await self.data.TrackedChannel.fetch_multiple(*cids)
# Create new sessions
rows = await self.data.VoiceSessionsOngoing.table.insert_many(
('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream',
'live_video', 'hourly_coins'),
*to_create
).with_adapter(self.data.VoiceSessionsOngoing._make_rows)
for row in rows:
VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)])
logger.info(
f"Started {len(rows)} new voice sessions from voice members "
f"in new guild '{guild.name}' <gid:{guildid}>."
)
@LionCog.listener("on_guild_remove") @LionCog.listener("on_guild_remove")
@log_wrap(action='Leave Guild Voice Sessions') @log_wrap(action='Leave Guild Voice Sessions')
@@ -645,10 +623,7 @@ class VoiceTrackerCog(LionCog):
now = utc_now() now = utc_now()
to_close = [] # (guildid, userid, _at) to_close = [] # (guildid, userid, _at)
for session in sessions.values(): for session in sessions.values():
if session.start_task is not None: session.cancel()
session.start_task.cancel()
if session.expiry_task is not None:
session.expiry_task.cancel()
to_close.append((session.guildid, session.userid, now)) to_close.append((session.guildid, session.userid, now))
if to_close: if to_close:
await self.data.VoiceSessionsOngoing.close_voice_sessions_at(*to_close) await self.data.VoiceSessionsOngoing.close_voice_sessions_at(*to_close)

View File

@@ -108,7 +108,7 @@ class VoiceTrackerData(Registry):
video_duration = Integer() video_duration = Integer()
stream_duration = Integer() stream_duration = Integer()
coins_earned = Integer() coins_earned = Integer()
last_update = Integer() last_update = Timestamp()
live_stream = Bool() live_stream = Bool()
live_video = Bool() live_video = Bool()
hourly_coins = Integer() hourly_coins = Integer()
@@ -154,7 +154,7 @@ class VoiceTrackerData(Registry):
async def update_voice_session_at( async def update_voice_session_at(
cls, guildid: int, userid: int, _at: dt.datetime, cls, guildid: int, userid: int, _at: dt.datetime,
stream: bool, video: bool, rate: float stream: bool, video: bool, rate: float
) -> int: ):
async with cls._connector.connection() as conn: async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(

View File

@@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, overload, Literal
from enum import IntEnum from enum import IntEnum
from collections import defaultdict from collections import defaultdict
import datetime as dt import datetime as dt
@@ -73,11 +73,14 @@ class VoiceSession:
'start_task', 'expiry_task', 'start_task', 'expiry_task',
'data', 'state', 'hourly_rate', 'data', 'state', 'hourly_rate',
'_tag', '_start_time', '_tag', '_start_time',
'lock',
'__weakref__' '__weakref__'
) )
_sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping _sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping
_active_sessions_ = defaultdict(dict) # Maintains strong references to active sessions
# Maintains strong references to active sessions
_active_sessions_: dict[int, dict[int, 'VoiceSession']] = defaultdict(dict)
def __init__(self, bot: LionBot, guildid: int, userid: int, data=None): def __init__(self, bot: LionBot, guildid: int, userid: int, data=None):
self.bot = bot self.bot = bot
@@ -96,6 +99,17 @@ class VoiceSession:
self._tag = None self._tag = None
self._start_time = None self._start_time = None
# Member session lock
# Ensures state changes are atomic and serialised
self.lock = asyncio.Lock()
def cancel(self):
if self.start_task is not None:
self.start_task.cancel()
if self.expiry_task is not None:
self.expiry_task.cancel()
self._active_sessions_[self.guildid].pop(self.userid, None)
@property @property
def tag(self) -> Optional[str]: def tag(self) -> Optional[str]:
if self.data: if self.data:
@@ -121,6 +135,16 @@ class VoiceSession:
else: else:
return SessionState.INACTIVE return SessionState.INACTIVE
@overload
@classmethod
def get(cls, bot: LionBot, guildid: int, userid: int, create: Literal[False]) -> Optional['VoiceSession']:
...
@overload
@classmethod
def get(cls, bot: LionBot, guildid: int, userid: int, create: Literal[True] = True) -> 'VoiceSession':
...
@classmethod @classmethod
def get(cls, bot: LionBot, guildid: int, userid: int, create=True) -> Optional['VoiceSession']: def get(cls, bot: LionBot, guildid: int, userid: int, create=True) -> Optional['VoiceSession']:
""" """
@@ -149,11 +173,12 @@ class VoiceSession:
return self return self
async def set_tag(self, new_tag): async def set_tag(self, new_tag):
if self.activity is SessionState.INACTIVE: async with self.lock:
raise ValueError("Cannot set tag on an inactive voice session.") if self.activity is SessionState.INACTIVE:
self._tag = new_tag raise ValueError("Cannot set tag on an inactive voice session.")
if self.data is not None: self._tag = new_tag
await self.data.update(tag=new_tag) if self.data is not None:
await self.data.update(tag=new_tag)
async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate): async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate):
""" """
@@ -167,6 +192,7 @@ class VoiceSession:
self.start_task = asyncio.create_task(self._start_after(delay, start_time)) self.start_task = asyncio.create_task(self._start_after(delay, start_time))
self.schedule_expiry(expire_time) self.schedule_expiry(expire_time)
self._active_sessions_[self.guildid][self.userid] = self
async def _start_after(self, delay: int, start_time: dt.datetime): async def _start_after(self, delay: int, start_time: dt.datetime):
""" """
@@ -174,36 +200,36 @@ class VoiceSession:
Creates the tracked_channel if required. Creates the tracked_channel if required.
""" """
self._active_sessions_[self.guildid][self.userid] = self
await asyncio.sleep(delay) await asyncio.sleep(delay)
logger.debug( async with self.lock:
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> " logger.info(
f"and channel <cid:{self.state.channelid}>." f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
) f"and channel <cid:{self.state.channelid}>."
# Create the lion if required )
await self.bot.core.lions.fetch_member(self.guildid, self.userid) # Create the lion if required
await self.bot.core.lions.fetch_member(self.guildid, self.userid)
# Create the tracked channel if required # Create the tracked channel if required
await self.registry.TrackedChannel.fetch_or_create( await self.registry.TrackedChannel.fetch_or_create(
self.state.channelid, guildid=self.guildid, deleted=False self.state.channelid, guildid=self.guildid, deleted=False
) )
# Insert an ongoing_session with the correct state, set data # Insert an ongoing_session with the correct state, set data
state = self.state state = self.state
self.data = await self.registry.VoiceSessionsOngoing.create( self.data = await self.registry.VoiceSessionsOngoing.create(
guildid=self.guildid, guildid=self.guildid,
userid=self.userid, userid=self.userid,
channelid=state.channelid, channelid=state.channelid,
start_time=start_time, start_time=start_time,
last_update=start_time, last_update=start_time,
live_stream=state.stream, live_stream=state.stream,
live_video=state.video, live_video=state.video,
hourly_coins=self.hourly_rate, hourly_coins=self.hourly_rate,
tag=self._tag tag=self._tag
) )
self.bot.dispatch('voice_session_start', self.data) self.bot.dispatch('voice_session_start', self.data)
self.start_task = None self.start_task = None
def schedule_expiry(self, expire_time): def schedule_expiry(self, expire_time):
""" """
@@ -258,33 +284,36 @@ class VoiceSession:
""" """
Close the session, or cancel the pending session. Idempotent. Close the session, or cancel the pending session. Idempotent.
""" """
if self.activity is SessionState.ONGOING: async with self.lock:
# End the ongoing session if self.activity is SessionState.ONGOING:
now = utc_now() # End the ongoing session
await self.data.close_study_session_at(self.guildid, self.userid, now) now = utc_now()
await self.data.close_study_session_at(self.guildid, self.userid, now)
# TODO: Something a bit saner/safer.. dispatch the finished session instead? # TODO: Something a bit saner/safer.. dispatch the finished session instead?
self.bot.dispatch('voice_session_end', self.data, now) self.bot.dispatch('voice_session_end', self.data, now)
# Rank update # Rank update
# TODO: Change to broadcasted event? # TODO: Change to broadcasted event?
rank_cog = self.bot.get_cog('RankCog') rank_cog = self.bot.get_cog('RankCog')
if rank_cog is not None: if rank_cog is not None:
asyncio.create_task(rank_cog.on_voice_session_complete( asyncio.create_task(rank_cog.on_voice_session_complete(
(self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0) (self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0)
)) ))
if self.start_task is not None: if self.start_task is not None:
self.start_task.cancel() self.start_task.cancel()
self.start_task = None self.start_task = None
if self.expiry_task is not None: if self.expiry_task is not None:
self.expiry_task.cancel() self.expiry_task.cancel()
self.expiry_task = None self.expiry_task = None
self.data = None self.data = None
self.state = None self.state = None
self.hourly_rate = None self.hourly_rate = None
self._tag = None
self._start_time = None
# Always release strong reference to session (to allow garbage collection) # Always release strong reference to session (to allow garbage collection)
self._active_sessions_[self.guildid].pop(self.userid) self._active_sessions_[self.guildid].pop(self.userid)

View File

@@ -34,7 +34,7 @@ _p = babel._p
class VoiceTrackerSettings(SettingGroup): class VoiceTrackerSettings(SettingGroup):
class UntrackedChannels(ListData, ChannelListSetting): class UntrackedChannels(ListData, ChannelListSetting):
setting_id = 'untracked_channels' setting_id = 'untracked_channels'
_event = 'guild_setting_update_untracked_channels' _event = 'guildset_untracked_channels'
_set_cmd = 'configure voice_rewards' _set_cmd = 'configure voice_rewards'
_display_name = _p('guildset:untracked_channels', "untracked_channels") _display_name = _p('guildset:untracked_channels', "untracked_channels")
@@ -111,7 +111,7 @@ class VoiceTrackerSettings(SettingGroup):
class HourlyReward(ModelData, IntegerSetting): class HourlyReward(ModelData, IntegerSetting):
setting_id = 'hourly_reward' setting_id = 'hourly_reward'
_event = 'guild_setting_update_hourly_reward' _event = 'on_guildset_hourly_reward'
_set_cmd = 'configure voice_rewards' _set_cmd = 'configure voice_rewards'
_display_name = _p('guildset:hourly_reward', "hourly_reward") _display_name = _p('guildset:hourly_reward', "hourly_reward")
@@ -191,7 +191,7 @@ class VoiceTrackerSettings(SettingGroup):
Guild setting describing the per-hour LionCoin bonus given to "live" members during tracking. Guild setting describing the per-hour LionCoin bonus given to "live" members during tracking.
""" """
setting_id = 'hourly_live_bonus' setting_id = 'hourly_live_bonus'
_event = 'guild_setting_update_hourly_live_bonus' _event = 'on_guildset_hourly_live_bonus'
_set_cmd = 'configure voice_rewards' _set_cmd = 'configure voice_rewards'
_display_name = _p('guildset:hourly_live_bonus', "hourly_live_bonus") _display_name = _p('guildset:hourly_live_bonus', "hourly_live_bonus")
@@ -242,7 +242,7 @@ class VoiceTrackerSettings(SettingGroup):
class DailyVoiceCap(ModelData, DurationSetting): class DailyVoiceCap(ModelData, DurationSetting):
setting_id = 'daily_voice_cap' setting_id = 'daily_voice_cap'
_event = 'guild_setting_update_daily_voice_cap' _event = 'on_guildset_daily_voice_cap'
_set_cmd = 'configure voice_rewards' _set_cmd = 'configure voice_rewards'
_display_name = _p('guildset:daily_voice_cap', "daily_voice_cap") _display_name = _p('guildset:daily_voice_cap', "daily_voice_cap")

View File

@@ -20,6 +20,7 @@ class MetaUtils(LionCog):
'cmd:page|desc', 'cmd:page|desc',
"Jump to a given page of the ouput of a previous command in this channel." "Jump to a given page of the ouput of a previous command in this channel."
), ),
with_app_command=False
) )
async def page_group(self, ctx: LionContext): async def page_group(self, ctx: LionContext):
""" """

View File

@@ -765,7 +765,7 @@ class Timezoned:
Return the start of the current month in the object's timezone Return the start of the current month in the object's timezone
""" """
today = self.today today = self.today
return today - datetime.timedelta(days=(today.day - 1)) return today.replace(day=1)
def replace_multiple(format_string, mapping): def replace_multiple(format_string, mapping):

View File

@@ -32,7 +32,7 @@ class TaskMonitor(Generic[Taskid]):
self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor
self._wakeup: asyncio.Event = asyncio.Event() self._wakeup: asyncio.Event = asyncio.Event()
self._monitor_task: Optional[self.Task] = None self._monitor_task: Optional[asyncio.Task] = None
# Task data # Task data
self._tasklist: list[Taskid] = [] self._tasklist: list[Taskid] = []
@@ -42,6 +42,19 @@ class TaskMonitor(Generic[Taskid]):
# And allows simpler external cancellation if required # And allows simpler external cancellation if required
self._running: dict[Taskid, asyncio.Future] = {} self._running: dict[Taskid, asyncio.Future] = {}
def __repr__(self):
return (
"<"
f"{self.__class__.__name__}"
f" tasklist={len(self._tasklist)}"
f" taskmap={len(self._taskmap)}"
f" wakeup={self._wakeup.is_set()}"
f" bucket={self._bucket}"
f" running={len(self._running)}"
f" task={self._monitor_task}"
f">"
)
def set_tasks(self, *tasks: tuple[Taskid, int]) -> None: def set_tasks(self, *tasks: tuple[Taskid, int]) -> None:
""" """
Similar to `schedule_tasks`, but wipe and reset the tasklist. Similar to `schedule_tasks`, but wipe and reset the tasklist.

View File

@@ -69,12 +69,12 @@ class DurationTransformer(Transformer):
name=t(_p( name=t(_p(
'util:Duration|acmpl|error', 'util:Duration|acmpl|error',
"Cannot extract duration from \"{partial}\"" "Cannot extract duration from \"{partial}\""
)).format(partial=partial), )).format(partial=partial)[:100],
value=partial value=partial
) )
else: else:
choice = appcmds.Choice( choice = appcmds.Choice(
name=strfdur(duration, short=False, show_days=True), name=strfdur(duration, short=False, show_days=True)[:100],
value=partial value=partial
) )
return [choice] return [choice]

View File

@@ -307,17 +307,17 @@ class Pager(BasePager):
"Current: Page {page}/{total}" "Current: Page {page}/{total}"
)).format(page=num+1, total=total) )).format(page=num+1, total=total)
choices = [ choices = [
appcmds.Choice(name=string, value=str(num+1)) appcmds.Choice(name=string[:100], value=str(num+1))
for num, string in sorted(page_choices.items(), key=lambda t: t[0]) for num, string in sorted(page_choices.items(), key=lambda t: t[0])
] ]
else: else:
# Particularly support page names here # Particularly support page names here
choices = [ choices = [
appcmds.Choice( appcmds.Choice(
name='> ' * (i == num) + t(_p( name=('> ' * (i == num) + t(_p(
'cmd:page|acmpl|pager:Pager|choice:general', 'cmd:page|acmpl|pager:Pager|choice:general',
"Page {page}" "Page {page}"
)).format(page=i+1), )).format(page=i+1))[:100],
value=str(i+1) value=str(i+1)
) )
for i in range(0, total) for i in range(0, total)
@@ -351,7 +351,7 @@ class Pager(BasePager):
name=t(_p( name=t(_p(
'cmd:page|acmpl|pager:Page|choice:select', 'cmd:page|acmpl|pager:Page|choice:select',
"Selected: Page {page}/{total}" "Selected: Page {page}/{total}"
)).format(page=page_num+1, total=total), )).format(page=page_num+1, total=total)[:100],
value=str(page_num + 1) value=str(page_num + 1)
) )
return [choice, *choices] return [choice, *choices]
@@ -361,7 +361,7 @@ class Pager(BasePager):
name=t(_p( name=t(_p(
'cmd:page|acmpl|pager:Page|error:parse', 'cmd:page|acmpl|pager:Page|error:parse',
"No matching pages!" "No matching pages!"
)).format(page=page_num, total=total), )).format(page=page_num, total=total)[:100],
value=partial value=partial
) )
] ]