From 5ac01d5cb208bacd52072be9795ca7a01b07659d Mon Sep 17 00:00:00 2001 From: Conatum Date: Sun, 8 Oct 2023 01:57:54 +0300 Subject: [PATCH] fix(voice): Avoid possible race condition. --- src/tracking/voice/session.py | 122 +++++++++++++++++++--------------- 1 file changed, 67 insertions(+), 55 deletions(-) diff --git a/src/tracking/voice/session.py b/src/tracking/voice/session.py index fe018581..b75ad302 100644 --- a/src/tracking/voice/session.py +++ b/src/tracking/voice/session.py @@ -73,11 +73,14 @@ class VoiceSession: 'start_task', 'expiry_task', 'data', 'state', 'hourly_rate', '_tag', '_start_time', + 'lock', '__weakref__' ) _sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping - _active_sessions_ = defaultdict(dict) # Maintains strong references to active sessions + + # Maintains strong references to active sessions + _active_sessions_: dict[int, dict[int, 'VoiceSession']] = defaultdict(dict) def __init__(self, bot: LionBot, guildid: int, userid: int, data=None): self.bot = bot @@ -96,6 +99,10 @@ class VoiceSession: self._tag = None self._start_time = None + # Member session lock + # Ensures state changes are atomic and serialised + self.lock = asyncio.Lock() + def cancel(self): if self.start_task is not None: self.start_task.cancel() @@ -166,11 +173,12 @@ class VoiceSession: return self async def set_tag(self, new_tag): - if self.activity is SessionState.INACTIVE: - raise ValueError("Cannot set tag on an inactive voice session.") - self._tag = new_tag - if self.data is not None: - await self.data.update(tag=new_tag) + async with self.lock: + if self.activity is SessionState.INACTIVE: + raise ValueError("Cannot set tag on an inactive voice session.") + self._tag = new_tag + if self.data is not None: + await self.data.update(tag=new_tag) async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate): """ @@ -194,33 +202,34 @@ class VoiceSession: """ await asyncio.sleep(delay) - logger.debug( - f"Starting voice session for member in guild " - f"and channel ." - ) - # Create the lion if required - await self.bot.core.lions.fetch_member(self.guildid, self.userid) + async with self.lock: + logger.info( + f"Starting voice session for member in guild " + f"and channel ." + ) + # Create the lion if required + await self.bot.core.lions.fetch_member(self.guildid, self.userid) - # Create the tracked channel if required - await self.registry.TrackedChannel.fetch_or_create( - self.state.channelid, guildid=self.guildid, deleted=False - ) + # Create the tracked channel if required + await self.registry.TrackedChannel.fetch_or_create( + self.state.channelid, guildid=self.guildid, deleted=False + ) - # Insert an ongoing_session with the correct state, set data - state = self.state - self.data = await self.registry.VoiceSessionsOngoing.create( - guildid=self.guildid, - userid=self.userid, - channelid=state.channelid, - start_time=start_time, - last_update=start_time, - live_stream=state.stream, - live_video=state.video, - hourly_coins=self.hourly_rate, - tag=self._tag - ) - self.bot.dispatch('voice_session_start', self.data) - self.start_task = None + # Insert an ongoing_session with the correct state, set data + state = self.state + self.data = await self.registry.VoiceSessionsOngoing.create( + guildid=self.guildid, + userid=self.userid, + channelid=state.channelid, + start_time=start_time, + last_update=start_time, + live_stream=state.stream, + live_video=state.video, + hourly_coins=self.hourly_rate, + tag=self._tag + ) + self.bot.dispatch('voice_session_start', self.data) + self.start_task = None def schedule_expiry(self, expire_time): """ @@ -275,33 +284,36 @@ class VoiceSession: """ Close the session, or cancel the pending session. Idempotent. """ - if self.activity is SessionState.ONGOING: - # End the ongoing session - now = utc_now() - await self.data.close_study_session_at(self.guildid, self.userid, now) + async with self.lock: + if self.activity is SessionState.ONGOING: + # End the ongoing session + now = utc_now() + await self.data.close_study_session_at(self.guildid, self.userid, now) - # TODO: Something a bit saner/safer.. dispatch the finished session instead? - self.bot.dispatch('voice_session_end', self.data, now) + # TODO: Something a bit saner/safer.. dispatch the finished session instead? + self.bot.dispatch('voice_session_end', self.data, now) - # Rank update - # TODO: Change to broadcasted event? - rank_cog = self.bot.get_cog('RankCog') - if rank_cog is not None: - asyncio.create_task(rank_cog.on_voice_session_complete( - (self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0) - )) + # Rank update + # TODO: Change to broadcasted event? + rank_cog = self.bot.get_cog('RankCog') + if rank_cog is not None: + asyncio.create_task(rank_cog.on_voice_session_complete( + (self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0) + )) - if self.start_task is not None: - self.start_task.cancel() - self.start_task = None + if self.start_task is not None: + self.start_task.cancel() + self.start_task = None - if self.expiry_task is not None: - self.expiry_task.cancel() - self.expiry_task = None + if self.expiry_task is not None: + self.expiry_task.cancel() + self.expiry_task = None - self.data = None - self.state = None - self.hourly_rate = None + self.data = None + self.state = None + self.hourly_rate = None + self._tag = None + self._start_time = None - # Always release strong reference to session (to allow garbage collection) - self._active_sessions_[self.guildid].pop(self.userid) + # Always release strong reference to session (to allow garbage collection) + self._active_sessions_[self.guildid].pop(self.userid)