From eaa44ab43cbc1fc95a29c3566065439714146822 Mon Sep 17 00:00:00 2001 From: Conatum Date: Fri, 6 Oct 2023 01:51:41 +0300 Subject: [PATCH] (voice): Rewrite initialise and refresh mechanism. --- src/meta/LionBot.py | 2 +- src/modules/member_admin/cog.py | 3 +- src/modules/pomodoro/cog.py | 2 +- src/modules/rooms/cog.py | 8 +- src/modules/statistics/data.py | 8 +- src/settings/ui.py | 2 +- src/tracking/voice/cog.py | 606 ++++++++++++++------------------ src/tracking/voice/data.py | 4 +- src/tracking/voice/session.py | 21 +- src/tracking/voice/settings.py | 8 +- src/utils/lib.py | 2 +- 11 files changed, 297 insertions(+), 369 deletions(-) diff --git a/src/meta/LionBot.py b/src/meta/LionBot.py index ca699ffb..7ee77493 100644 --- a/src/meta/LionBot.py +++ b/src/meta/LionBot.py @@ -46,7 +46,7 @@ class LionBot(Bot): # self.appdata = appdata self.config = config self.app_ipc = app_ipc - self.core: Optional['CoreCog'] = None + self.core: 'CoreCog' = None self.translator = translator self.system_monitor = SystemMonitor() diff --git a/src/modules/member_admin/cog.py b/src/modules/member_admin/cog.py index 6887ed6d..a707a3ea 100644 --- a/src/modules/member_admin/cog.py +++ b/src/modules/member_admin/cog.py @@ -227,7 +227,8 @@ class MemberAdminCog(LionCog): logger.info(f"Cleared persisting roles for guild because we left the guild.") @LionCog.listener('on_guildset_role_persistence') - async def clear_stored_roles(self, guildid, data): + async def clear_stored_roles(self, guildid, setting: MemberAdminSettings.RolePersistence): + data = setting.data if data is False: await self.data.past_roles.delete_where(guildid=guildid) logger.info( diff --git a/src/modules/pomodoro/cog.py b/src/modules/pomodoro/cog.py index 12a6b3ae..e4e659de 100644 --- a/src/modules/pomodoro/cog.py +++ b/src/modules/pomodoro/cog.py @@ -343,7 +343,7 @@ class TimerCog(LionCog): @LionCog.listener('on_guildset_pomodoro_channel') @log_wrap(action='Update Pomodoro Channels') - async def _update_pomodoro_channels(self, guildid: int, data: Optional[int]): + async def _update_pomodoro_channels(self, guildid: int, setting: TimerSettings.PomodoroChannel): """ Request a send_status for all guild timers which need to move channel. """ diff --git a/src/modules/rooms/cog.py b/src/modules/rooms/cog.py index e4543cbd..4b2a6d70 100644 --- a/src/modules/rooms/cog.py +++ b/src/modules/rooms/cog.py @@ -173,14 +173,15 @@ class RoomCog(LionCog): # Setting event handlers @LionCog.listener('on_guildset_rooms_category') @log_wrap(action='Update Rooms Category') - async def _update_rooms_category(self, guildid: int, data: Optional[int]): + async def _update_rooms_category(self, guildid: int, setting: RoomSettings.Category): """ Move all active private channels to the new category. This shouldn't affect the channel function at all. """ + data = setting.data guild = self.bot.get_guild(guildid) - new_category = guild.get_channel(data) if guild else None + new_category = guild.get_channel(data) if guild and data else None if new_category: tasks = [] for room in list(self._room_cache[guildid].values()): @@ -196,10 +197,11 @@ class RoomCog(LionCog): @LionCog.listener('on_guildset_rooms_visible') @log_wrap(action='Update Rooms Visibility') - async def _update_rooms_visibility(self, guildid: int, data: bool): + async def _update_rooms_visibility(self, guildid: int, setting: RoomSettings.Visible): """ Update the everyone override on each room to reflect the new setting. """ + data = setting.data tasks = [] for room in list(self._room_cache[guildid].values()): if room.channel: diff --git a/src/modules/statistics/data.py b/src/modules/statistics/data.py index 76ba4cf6..42062254 100644 --- a/src/modules/statistics/data.py +++ b/src/modules/statistics/data.py @@ -122,7 +122,7 @@ class StatsData(Registry): "SELECT study_time_between(%s, %s, %s, %s)", (guildid, userid, _start, _end) ) - return (await cursor.fetchone()[0]) or 0 + return (await cursor.fetchone())[0] or 0 @classmethod @log_wrap(action='study_times_between') @@ -162,11 +162,11 @@ class StatsData(Registry): "SELECT study_time_since(%s, %s, %s)", (guildid, userid, _start) ) - return (await cursor.fetchone()[0]) or 0 + return (await cursor.fetchone())[0] or 0 @classmethod @log_wrap(action='study_times_since') - async def study_times_since(cls, guildid: Optional[int], userid: int, *starts) -> int: + async def study_times_since(cls, guildid: Optional[int], userid: int, *starts) -> list[int]: if len(starts) < 1: raise ValueError('No starting points given!') @@ -251,7 +251,7 @@ class StatsData(Registry): return leaderboard @classmethod - @log_wrap('leaderboard_all') + @log_wrap(action='leaderboard_all') async def leaderboard_all(cls, guildid: int): """ Return the all-time voice totals for the given guild. diff --git a/src/settings/ui.py b/src/settings/ui.py index 719fb7e1..e53a874c 100644 --- a/src/settings/ui.py +++ b/src/settings/ui.py @@ -236,7 +236,7 @@ class InteractiveSetting(BaseSetting[ParentID, SettingData, SettingValue]): Callable[[ParentID, SettingData], Coroutine[Any, Any, None]] """ if self._event is not None and (bot := ctx_bot.get()) is not None: - bot.dispatch(self._event, self.parent_id, self.data) + bot.dispatch(self._event, self.parent_id, self) def get_listener(self, key): return self._listeners_.get(key, None) diff --git a/src/tracking/voice/cog.py b/src/tracking/voice/cog.py index d5ecb3c8..7dbd82fe 100644 --- a/src/tracking/voice/cog.py +++ b/src/tracking/voice/cog.py @@ -1,17 +1,16 @@ from typing import Optional import asyncio import datetime as dt -from collections import defaultdict import discord from discord.ext import commands as cmds from discord import app_commands as appcmds +from data import Condition from meta import LionBot, LionCog, LionContext -from meta.errors import UserInputError -from meta.logger import log_wrap, logging_context +from meta.logger import log_wrap from meta.sharding import THIS_SHARD -from utils.lib import utc_now, error_embed +from utils.lib import utc_now from core.lion_guild import VoiceMode from wards import low_management_ward, moderator_ctxward @@ -44,6 +43,8 @@ class VoiceTrackerCog(LionCog): self.untracked_channels = self.settings.UntrackedChannels._cache + self.active_sessions = VoiceSession._active_sessions_ + async def cog_load(self): await self.data.init() @@ -71,7 +72,8 @@ class VoiceTrackerCog(LionCog): # Simultaneously! ... - def get_session(self, guildid, userid, **kwargs) -> VoiceSession: + # ----- Cog API ----- + def get_session(self, guildid, userid, **kwargs) -> Optional[VoiceSession]: """ Get the VoiceSession for the given member. @@ -91,6 +93,197 @@ class VoiceTrackerCog(LionCog): untracked = False return untracked + @log_wrap(action='load sessions') + async def _load_sessions(self, + states: dict[tuple[int, int], TrackedVoiceState], + ongoing: list[VoiceTrackerData.VoiceSessionsOngoing]): + """ + Load voice sessions from provided states and ongoing data. + + Provided data may cross multiple guilds. + Assumes all states which do not have data should be started. + Assumes all ongoing data which does not have states should be ended. + Assumes untracked channel data is up to date. + """ + OngoingData = VoiceTrackerData.VoiceSessionsOngoing + + # Compute time to end complete sessions + now = utc_now() + last_update = max((row.last_update for row in ongoing), default=now) + end_at = min(last_update + dt.timedelta(seconds=3600), now) + + # Bulk fetches for voice-active members and guilds + active_memberids = list(states.keys()) + active_guildids = set(gid for gid, _ in states) + + if states: + lguilds = await self.bot.core.lions.fetch_guilds(*active_guildids) + await self.bot.core.lions.fetch_members(*active_memberids) + tracked_today_data = await self.data.VoiceSessions.multiple_voice_tracked_since( + *((guildid, userid, lguilds[guildid].today) for guildid, userid in active_memberids) + ) + tracked_today = {(row['guildid'], row['userid']): row['tracked'] for row in tracked_today_data} + else: + lguilds = {} + tracked_today = {} + + # Zip session information together by memberid keys + sessions: dict[tuple[int, int], tuple[Optional[TrackedVoiceState], Optional[OngoingData]]] = {} + for row in ongoing: + key = (row.guildid, row.userid) + sessions[key] = (states.pop(key, None), row) + for key, state in states.items(): + sessions[key] = (state, None) + + # Now split up session information to fill action maps + close_ongoing = [] + update_ongoing = [] + create_ongoing = [] + expiries = {} + load_sessions = [] + schedule_sessions = {} + + for (gid, uid), (state, data) in sessions.items(): + if state is not None: + # Member is active + if data is not None and data.channelid != state.channelid: + # Ongoing session does not match active state + # Close the session, but still create/schedule the state + close_ongoing.append((gid, uid, end_at)) + data = None + + # Now create/update/schedule active session + # Also create/update data if required + lguild = lguilds[gid] + tomorrow = lguild.today + dt.timedelta(days=1) + cap = lguild.config.get('daily_voice_cap').value + tracked = tracked_today[gid, uid] + hourly_rate = await self._calculate_rate(gid, uid, state) + + if tracked >= cap: + # Active session is already over cap + # Stop ongoing if it exists, and schedule next session start + delay = (tomorrow - now).total_seconds() + start_time = tomorrow + expiry = tomorrow + dt.timedelta(seconds=cap) + schedule_sessions[(gid, uid)] = (delay, start_time, expiry, state, hourly_rate) + if data is not None: + close_ongoing.append(( + gid, uid, + max(now - dt.timedelta(seconds=tracked - cap), data.last_update) + )) + else: + # Active session, update/create data + expiry = now + dt.timedelta(seconds=(cap - tracked)) + if expiry > tomorrow: + expiry = tomorrow + dt.timedelta(seconds=cap) + expiries[(gid, uid)] = expiry + if data is not None: + update_ongoing.append((gid, uid, now, state.stream, state.video, hourly_rate)) + else: + create_ongoing.append(( + gid, uid, state.channelid, now, now, state.stream, state.video, hourly_rate + )) + elif data is not None: + # Ongoing data has no state, close the session + close_ongoing.append((gid, uid, end_at)) + + # Close data that needs closing + if close_ongoing: + logger.info( + f"Ending {len(close_ongoing)} ongoing voice sessions with no matching voice state." + ) + await self.data.VoiceSessionsOngoing.close_voice_sessions_at(*close_ongoing) + + # Update data that needs updating + if update_ongoing: + logger.info( + f"Continuing {len(update_ongoing)} ongoing voice sessions with matching voice state." + ) + rows = await self.data.VoiceSessionsOngoing.update_voice_sessions_at(*update_ongoing) + load_sessions.extend(rows) + + # Create data that needs creating + if create_ongoing: + logger.info( + f"Creating {len(create_ongoing)} voice sessions from new voice states." + ) + # First ensure the tracked channels exist + cids = set((item[2], item[0]) for item in create_ongoing) + await self.data.TrackedChannel.fetch_multiple(*cids) + + # Then create the sessions + rows = await self.data.VoiceSessionsOngoing.table.insert_many( + ('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream', + 'live_video', 'hourly_coins'), + *create_ongoing + ).with_adapter(self.data.VoiceSessionsOngoing._make_rows) + load_sessions.extend(rows) + + # Create sessions from ongoing, with expiry + for row in load_sessions: + VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)]) + + # Schedule starting sessions + for (gid, uid), args in schedule_sessions.items(): + session = VoiceSession.get(self.bot, gid, uid) + await session.schedule_start(*args) + + logger.info( + f"Successfully loaded {len(load_sessions)} and scheduled {len(schedule_sessions)} voice sessions." + ) + + @log_wrap(action='refresh guild sessions') + async def refresh_guild_sessions(self, guild: discord.Guild): + """ + Idempotently refresh all guild voice sessions in the given guild. + + Essentially a lighter version of `initialise`. + """ + # TODO: There is a very small potential window for a race condition here + # Since we do not have a version of 'handle_events' for the guild + # We may actually handle events before starting refresh + # Causing sessions to have invalid state. + # If this becomes an actual problem, implement an `ignore_guilds` set flag of some form... + logger.debug(f"Beginning voice state refresh for ") + + async with self.tracking_lock: + # TODO: Add a 'lock holder' attribute which is readable by the monitor + logger.debug(f"Voice state refresh for is past lock") + + # Deactivate any ongoing session tasks in this guild + active = self.active_sessions.pop(guild.id, {}).values() + for session in active: + session.cancel() + + # Update untracked channel information for this guild + self.untracked_channels.pop(guild.id, None) + await self.settings.UntrackedChannels.get(guild.id) + + # Read tracked voice states + states = {} + for channel in guild.voice_channels: + if not self.is_untracked(channel): + for member in channel.members: + if member.voice and not member.bot: + state = TrackedVoiceState.from_voice_state(member.voice) + states[(guild.id, member.id)] = state + logger.debug(f"Loaded {len(states)} tracked voice states for .") + + # Read ongoing session data + ongoing = await self.data.VoiceSessionsOngoing.fetch_where(guildid=guild.id) + logger.debug( + f"Loaded {len(ongoing)} ongoing voice sessions from data for . Beginning reload." + ) + + await self._load_sessions(states, ongoing) + logger.info( + f"Completed guild voice session reload for " + f"with '{len(self.active_sessions[guild.id])}' active sessions." + ) + + + # ----- Event Handlers ----- @LionCog.listener('on_ready') @log_wrap(action='Init Voice Sessions') async def initialise(self): @@ -99,192 +292,54 @@ class VoiceTrackerCog(LionCog): Ends ongoing sessions for members who are not in the given voice channel. """ - # First take the tracking lock - # Ensures current event handling completes before re-initialisation + logger.info("Beginning voice session state initialisation. Disabling voice event handling.") + # If `on_ready` is called, that means we are initialising + # or we missed events and need to re-initialise. + # Start ignoring events because they may be working on stale or partial state + self.handle_events = False + + # Services which read our cache should wait for initialisation before taking the lock + self.initialised.clear() + + # Wait for running events to complete + # And make sure future events will be processed after initialisation + # Note only events occurring after our voice state snapshot will be processed async with self.tracking_lock: - logger.info("Reloading ongoing voice sessions") + # Deactivate all ongoing sessions + active = [session for gsessions in self.active_sessions.values() for session in gsessions.values()] + for session in active: + session.cancel() + self.active_sessions.clear() + + # Also clear the session registry cache + VoiceSession._sessions_.clear() + + # Refresh untracked information for all guilds we are in + await self.settings.UntrackedChannels.setup(self.bot) - logger.debug("Disabling voice state event handling.") - self.handle_events = False - self.initialised.clear() # Read and save the tracked voice states of all visible voice channels - voice_members = {} # (guildid, userid) -> TrackedVoiceState - voice_guilds = set() + states = {} for guild in self.bot.guilds: - untracked = self.untracked_channels.get(guild.id, ()) for channel in guild.voice_channels: - if channel.id in untracked: - continue - if channel.category_id and channel.category_id in untracked: - continue + if not self.is_untracked(channel): + for member in channel.members: + if member.voice and not member.bot: + state = TrackedVoiceState.from_voice_state(member.voice) + states[(guild.id, member.id)] = state - for member in channel.members: - if member.bot: - continue - voice_members[(guild.id, member.id)] = TrackedVoiceState.from_voice_state(member.voice) - voice_guilds.add(guild.id) - - logger.debug(f"Cached {len(voice_members)} members from voice channels.") + logger.info( + f"Saved voice snapshot with {len(states)} tracked states. Re-enabling voice event handling." + ) self.handle_events = True - logger.debug("Re-enabled voice state event handling.") - # Iterate through members with current ongoing sessions - # End or update sessions as needed, based on saved tracked state - ongoing_rows = await self.data.VoiceSessionsOngoing.fetch_where( - guildid=[guild.id for guild in self.bot.guilds] + # Load ongoing session data for the entire shard + ongoing = await self.data.VoiceSessionsOngoing.fetch_where(THIS_SHARD) + logger.info( + f"Retrieved {len(ongoing)} ongoing voice sessions from data. Beginning reload." ) - logger.debug( - f"Loaded {len(ongoing_rows)} ongoing sessions from data. Splitting into complete and incomplete." - ) - complete = [] - incomplete = [] - incomplete_guildids = set() - # Compute time to end complete sessions - now = utc_now() - last_update = max((row.last_update for row in ongoing_rows), default=now) - end_at = min(last_update + dt.timedelta(seconds=3600), now) + await self._load_sessions(states, ongoing) - for row in ongoing_rows: - key = (row.guildid, row.userid) - state = voice_members.get(key, None) - untracked = self.untracked_channels.get(row.guildid, []) - if ( - state - and state.channelid == row.channelid - and state.channelid not in untracked - and (ch := self.bot.get_channel(state.channelid)) is not None - and (not ch.category_id or ch.category_id not in untracked) - ): - # Mark session as ongoing - incomplete.append((row, state)) - incomplete_guildids.add(row.guildid) - voice_members.pop(key) - else: - # Mark session as complete - complete.append((row.guildid, row.userid, end_at)) - - # Load required guild data into cache - active_guildids = incomplete_guildids.union(voice_guilds) - if active_guildids: - await self.bot.core.data.Guild.fetch_where(guildid=tuple(active_guildids)) - lguilds = {guildid: await self.bot.core.lions.fetch_guild(guildid) for guildid in active_guildids} - - # Calculate tracked_today for members with ongoing sessions - active_members = set((row.guildid, row.userid) for row, _ in incomplete) - active_members.update(voice_members.keys()) - if active_members: - tracked_today_data = await self.data.VoiceSessions.multiple_voice_tracked_since( - *((guildid, userid, lguilds[guildid].today) for guildid, userid in active_members) - ) - else: - tracked_today_data = [] - tracked_today = {(row['guildid'], row['userid']): row['tracked'] for row in tracked_today_data} - - if incomplete: - # Note that study_time_since _includes_ ongoing sessions in its calculation - # So expiry times are "time left today until cap" or "tomorrow + cap" - to_load = [] # (session_data, expiry_time) - to_update = [] # (guildid, userid, update_at, stream, video, hourly_rate) - for session_data, state in incomplete: - # Calculate expiry times - lguild = lguilds[session_data.guildid] - cap = lguild.config.get('daily_voice_cap').value - tracked = tracked_today[(session_data.guildid, session_data.userid)] - if tracked >= cap: - # Already over cap - complete.append(( - session_data.guildid, - session_data.userid, - max(now + dt.timedelta(seconds=tracked - cap), session_data.last_update) - )) - else: - tomorrow = lguild.today + dt.timedelta(days=1) - expiry = now + dt.timedelta(seconds=(cap - tracked)) - if expiry > tomorrow: - expiry = tomorrow + dt.timedelta(seconds=cap) - to_load.append((session_data, expiry)) - - # TODO: Probably better to do this by batch - # Could force all bonus calculators to accept list of members - hourly_rate = await self._calculate_rate(session_data.guildid, session_data.userid, state) - to_update.append(( - session_data.guildid, - session_data.userid, - now, - state.stream, - state.video, - hourly_rate - )) - # Run the updates, note that session_data uses registry pattern so will also update - if to_update: - await self.data.VoiceSessionsOngoing.update_voice_sessions_at(*to_update) - - # Load the sessions - for data, expiry in to_load: - VoiceSession.from_ongoing(self.bot, data, expiry) - - logger.info(f"Resumed {len(to_load)} ongoing voice sessions.") - - if complete: - logger.info(f"Ending {len(complete)} out-of-date or expired study sessions.") - - # Complete sessions just need a mass end_voice_session_at() - await self.data.VoiceSessionsOngoing.close_voice_sessions_at(*complete) - - # Then iterate through the saved states from tracked voice channels - # Start sessions if they don't already exist - if voice_members: - expiries = {} # (guildid, memberid) -> expiry time - to_create = [] # (guildid, userid, channelid, start_time, last_update, live_stream, live_video, rate) - for (guildid, userid), state in voice_members.items(): - untracked = self.untracked_channels.get(guildid, []) - channel = self.bot.get_channel(state.channelid) - if ( - channel - and channel.id not in untracked - and (not channel.category_id or channel.category_id not in untracked) - ): - # State is from member in tracked voice channel - # Calculate expiry - lguild = lguilds[guildid] - cap = lguild.config.get('daily_voice_cap').value - tracked = tracked_today[(guildid, userid)] - if tracked < cap: - tomorrow = lguild.today + dt.timedelta(days=1) - expiry = now + dt.timedelta(seconds=(cap - tracked)) - if expiry > tomorrow: - expiry = tomorrow + dt.timedelta(seconds=cap) - expiries[(guildid, userid)] = expiry - - hourly_rate = await self._calculate_rate(guildid, userid, state) - to_create.append(( - guildid, userid, - state.channelid, - now, now, - state.stream, state.video, - hourly_rate - )) - # Bulk create the ongoing sessions - if to_create: - # First ensure the lion members exist - await self.bot.core.lions.fetch_members( - *(item[:2] for item in to_create) - ) - - # Then ensure the TrackedChannels exist - cids = set((item[2], item[0]) for item in to_create) - await self.data.TrackedChannel.fetch_multiple(*cids) - - # Then actually create the ongoing sessions - rows = await self.data.VoiceSessionsOngoing.table.insert_many( - ('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream', - 'live_video', 'hourly_coins'), - *to_create - ).with_adapter(self.data.VoiceSessionsOngoing._make_rows) - for row in rows: - VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)]) - logger.info(f"Started {len(rows)} new voice sessions from voice channels!") self.initialised.set() @LionCog.listener("on_voice_state_update") @@ -391,116 +446,24 @@ class VoiceTrackerCog(LionCog): hourly_rate = await self._calculate_rate(member.guild.id, member.id, astate) await session.update(new_state=astate, new_rate=hourly_rate) - @LionCog.listener("on_guild_setting_update_untracked_channels") - async def update_untracked_channels(self, guildid, setting): - """ - Close sessions in untracked channels, and recalculate previously untracked sessions - """ + @LionCog.listener("on_guildset_untracked_channels") + @LionCog.listener("on_guildset_hourly_reward") + @LionCog.listener("on_guildset_hourly_live_bonus") + @LionCog.listener("on_guildset_daily_voice_cap") + @LionCog.listener("on_guildset_timezone") + async def _event_refresh_guild(self, guildid: int, setting): if not self.handle_events: return - - async with self.tracking_lock: - lguild = await self.bot.core.lions.fetch_guild(guildid) - guild = self.bot.get_guild(guildid) - if not guild: - # Left guild while waiting on lock - return - cap = lguild.config.get('daily_voice_cap').value - untracked = self.untracked_channels.get(guildid, []) - now = utc_now() - - # Iterate through active sessions, close any that are in untracked channels - active = VoiceSession._active_sessions_.get(guildid, {}) - for session in list(active.values()): - if session.state.channelid in untracked: - await session.close() - - # Iterate through voice members, open new sessions if needed - expiries = {} - to_create = [] - for channel in guild.voice_channels: - if channel.id in untracked: - continue - for member in channel.members: - if self.get_session(guildid, member.id).activity: - # Already have an active session for this member - continue - userid = member.id - state = TrackedVoiceState.from_voice_state(member.voice) - - # TODO: Take into account tracked_today time? - # TODO: Make a per-guild refresh function to stay DRY - tomorrow = lguild.today + dt.timedelta(days=1) - expiry = now + dt.timedelta(seconds=cap) - if expiry > tomorrow: - expiry = tomorrow + dt.timedelta(seconds=cap) - expiries[(guildid, userid)] = expiry - - hourly_rate = await self._calculate_rate(guildid, userid, state) - to_create.append(( - guildid, userid, - state.channelid, - now, now, - state.stream, state.video, - hourly_rate - )) - - if to_create: - # Ensure LionMembers exist - await self.bot.core.lions.fetch_members( - *(item[:2] for item in to_create) - ) - - # Ensure TrackedChannels exist - cids = set((item[2], item[0]) for item in to_create) - await self.data.TrackedChannel.fetch_multiple(*cids) - - # Create new sessions - rows = await self.data.VoiceSessionsOngoing.table.insert_many( - ('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream', - 'live_video', 'hourly_coins'), - *to_create - ).with_adapter(self.data.VoiceSessionsOngoing._make_rows) - for row in rows: - VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)]) - logger.info( - f"Started {len(rows)} new voice sessions from voice members " - f"in previously untracked channels of guild '{guild.name}' ." - ) - - @LionCog.listener("on_guild_setting_update_hourly_reward") - async def update_hourly_reward(self, guildid, setting): - if not self.handle_events: - return - - async with self.tracking_lock: - sessions = VoiceSession._active_sessions_.get(guildid, {}) - for session in list(sessions.values()): - hourly_rate = await self._calculate_rate(session.guildid, session.userid, session.state) - await session.update(new_rate=hourly_rate) - - @LionCog.listener("on_guild_setting_update_hourly_live_bonus") - async def update_hourly_live_bonus(self, guildid, setting): - if not self.handle_events: - return - - async with self.tracking_lock: - sessions = VoiceSession._active_sessions_.get(guildid, {}) - for session in list(sessions.values()): - hourly_rate = await self._calculate_rate(session.guildid, session.userid, session.state) - await session.update(new_rate=hourly_rate) - - @LionCog.listener("on_guild_setting_update_daily_voice_cap") - async def update_daily_voice_cap(self, guildid, setting): - # TODO: Guild daily_voice_cap setting triggers session expiry recalculation for all sessions - ... - - @LionCog.listener("on_guild_setting_update_timezone") - @log_wrap(action='Voice Track') - @log_wrap(action='Timezone Update') - async def update_timezone(self, guildid, setting): - # TODO: Guild timezone setting triggers studied_today cache rebuild - logger.info("Received dispatch event for timezone change!") + guild = self.bot.get_guild(guildid) + if guild is None: + logger.warning( + f"Voice tracker discarding '{setting.setting_id}' event for unknown guild ." + ) + else: + logger.debug( + f"Voice tracker handling '{setting.setting_id}' event for guild ." + ) + await self.refresh_guild_sessions(guild) async def _calculate_rate(self, guildid, userid, state): """ @@ -522,7 +485,7 @@ class VoiceTrackerCog(LionCog): return hourly_rate - async def _session_boundaries_for(self, guildid: int, userid: int) -> tuple[int, dt.datetime, dt.datetime]: + async def _session_boundaries_for(self, guildid: int, userid: int) -> tuple[float, dt.datetime, dt.datetime]: """ Compute when the next session for this member should start and expire. @@ -539,7 +502,7 @@ class VoiceTrackerCog(LionCog): """ lguild = await self.bot.core.lions.fetch_guild(guildid) now = lguild.now - tomorrow = now + dt.timedelta(days=1) + tomorrow = lguild.today + dt.timedelta(days=1) studied_today = await self.fetch_tracked_today(guildid, userid) cap = lguild.config.get('daily_voice_cap').value @@ -552,7 +515,7 @@ class VoiceTrackerCog(LionCog): delay = 20 expiry = start_time + dt.timedelta(seconds=cap) - if expiry >= tomorrow: + if expiry > tomorrow: expiry = tomorrow + dt.timedelta(seconds=cap) return (delay, start_time, expiry) @@ -574,61 +537,9 @@ class VoiceTrackerCog(LionCog): Initialise and start required new sessions from voice channel members when we join a guild. """ if not self.handle_events: + # Initialisation will take care of it for us return - - async with self.tracking_lock: - guildid = guild.id - lguild = await self.bot.core.lions.fetch_guild(guildid) - cap = lguild.config.get('daily_voice_cap').value - untracked = self.untracked_channels.get(guildid, []) - now = utc_now() - - expiries = {} - to_create = [] - for channel in guild.voice_channels: - if channel.id in untracked: - continue - for member in channel.members: - userid = member.id - state = TrackedVoiceState.from_voice_state(member.voice) - - tomorrow = lguild.today + dt.timedelta(days=1) - expiry = now + dt.timedelta(seconds=cap) - if expiry > tomorrow: - expiry = tomorrow + dt.timedelta(seconds=cap) - expiries[(guildid, userid)] = expiry - - hourly_rate = await self._calculate_rate(guildid, userid, state) - to_create.append(( - guildid, userid, - state.channelid, - now, now, - state.stream, state.video, - hourly_rate - )) - - if to_create: - # Ensure LionMembers exist - await self.bot.core.lions.fetch_members( - *(item[:2] for item in to_create) - ) - - # Ensure TrackedChannels exist - cids = set((item[2], item[0]) for item in to_create) - await self.data.TrackedChannel.fetch_multiple(*cids) - - # Create new sessions - rows = await self.data.VoiceSessionsOngoing.table.insert_many( - ('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream', - 'live_video', 'hourly_coins'), - *to_create - ).with_adapter(self.data.VoiceSessionsOngoing._make_rows) - for row in rows: - VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)]) - logger.info( - f"Started {len(rows)} new voice sessions from voice members " - f"in new guild '{guild.name}' ." - ) + await self.refresh_guild_sessions(guild) @LionCog.listener("on_guild_remove") @log_wrap(action='Leave Guild Voice Sessions') @@ -645,10 +556,7 @@ class VoiceTrackerCog(LionCog): now = utc_now() to_close = [] # (guildid, userid, _at) for session in sessions.values(): - if session.start_task is not None: - session.start_task.cancel() - if session.expiry_task is not None: - session.expiry_task.cancel() + session.cancel() to_close.append((session.guildid, session.userid, now)) if to_close: await self.data.VoiceSessionsOngoing.close_voice_sessions_at(*to_close) diff --git a/src/tracking/voice/data.py b/src/tracking/voice/data.py index c003a4a2..86c5e500 100644 --- a/src/tracking/voice/data.py +++ b/src/tracking/voice/data.py @@ -108,7 +108,7 @@ class VoiceTrackerData(Registry): video_duration = Integer() stream_duration = Integer() coins_earned = Integer() - last_update = Integer() + last_update = Timestamp() live_stream = Bool() live_video = Bool() hourly_coins = Integer() @@ -154,7 +154,7 @@ class VoiceTrackerData(Registry): async def update_voice_session_at( cls, guildid: int, userid: int, _at: dt.datetime, stream: bool, video: bool, rate: float - ) -> int: + ): async with cls._connector.connection() as conn: async with conn.cursor() as cursor: await cursor.execute( diff --git a/src/tracking/voice/session.py b/src/tracking/voice/session.py index 5f5766aa..fe018581 100644 --- a/src/tracking/voice/session.py +++ b/src/tracking/voice/session.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, overload, Literal from enum import IntEnum from collections import defaultdict import datetime as dt @@ -96,6 +96,13 @@ class VoiceSession: self._tag = None self._start_time = None + def cancel(self): + if self.start_task is not None: + self.start_task.cancel() + if self.expiry_task is not None: + self.expiry_task.cancel() + self._active_sessions_[self.guildid].pop(self.userid, None) + @property def tag(self) -> Optional[str]: if self.data: @@ -121,6 +128,16 @@ class VoiceSession: else: return SessionState.INACTIVE + @overload + @classmethod + def get(cls, bot: LionBot, guildid: int, userid: int, create: Literal[False]) -> Optional['VoiceSession']: + ... + + @overload + @classmethod + def get(cls, bot: LionBot, guildid: int, userid: int, create: Literal[True] = True) -> 'VoiceSession': + ... + @classmethod def get(cls, bot: LionBot, guildid: int, userid: int, create=True) -> Optional['VoiceSession']: """ @@ -167,6 +184,7 @@ class VoiceSession: self.start_task = asyncio.create_task(self._start_after(delay, start_time)) self.schedule_expiry(expire_time) + self._active_sessions_[self.guildid][self.userid] = self async def _start_after(self, delay: int, start_time: dt.datetime): """ @@ -174,7 +192,6 @@ class VoiceSession: Creates the tracked_channel if required. """ - self._active_sessions_[self.guildid][self.userid] = self await asyncio.sleep(delay) logger.debug( diff --git a/src/tracking/voice/settings.py b/src/tracking/voice/settings.py index 4f4d7387..a74bd541 100644 --- a/src/tracking/voice/settings.py +++ b/src/tracking/voice/settings.py @@ -34,7 +34,7 @@ _p = babel._p class VoiceTrackerSettings(SettingGroup): class UntrackedChannels(ListData, ChannelListSetting): setting_id = 'untracked_channels' - _event = 'guild_setting_update_untracked_channels' + _event = 'guildset_untracked_channels' _set_cmd = 'configure voice_rewards' _display_name = _p('guildset:untracked_channels', "untracked_channels") @@ -111,7 +111,7 @@ class VoiceTrackerSettings(SettingGroup): class HourlyReward(ModelData, IntegerSetting): setting_id = 'hourly_reward' - _event = 'guild_setting_update_hourly_reward' + _event = 'on_guildset_hourly_reward' _set_cmd = 'configure voice_rewards' _display_name = _p('guildset:hourly_reward', "hourly_reward") @@ -191,7 +191,7 @@ class VoiceTrackerSettings(SettingGroup): Guild setting describing the per-hour LionCoin bonus given to "live" members during tracking. """ setting_id = 'hourly_live_bonus' - _event = 'guild_setting_update_hourly_live_bonus' + _event = 'on_guildset_hourly_live_bonus' _set_cmd = 'configure voice_rewards' _display_name = _p('guildset:hourly_live_bonus', "hourly_live_bonus") @@ -242,7 +242,7 @@ class VoiceTrackerSettings(SettingGroup): class DailyVoiceCap(ModelData, DurationSetting): setting_id = 'daily_voice_cap' - _event = 'guild_setting_update_daily_voice_cap' + _event = 'on_guildset_daily_voice_cap' _set_cmd = 'configure voice_rewards' _display_name = _p('guildset:daily_voice_cap', "daily_voice_cap") diff --git a/src/utils/lib.py b/src/utils/lib.py index 10babc85..98d60c57 100644 --- a/src/utils/lib.py +++ b/src/utils/lib.py @@ -765,7 +765,7 @@ class Timezoned: Return the start of the current month in the object's timezone """ today = self.today - return today - datetime.timedelta(days=(today.day - 1)) + return today.replace(day=1) def replace_multiple(format_string, mapping):