fix(voice): Avoid possible race condition.

This commit is contained in:
2023-10-08 01:57:54 +03:00
parent 85bbe527be
commit 5ac01d5cb2

View File

@@ -73,11 +73,14 @@ class VoiceSession:
'start_task', 'expiry_task', 'start_task', 'expiry_task',
'data', 'state', 'hourly_rate', 'data', 'state', 'hourly_rate',
'_tag', '_start_time', '_tag', '_start_time',
'lock',
'__weakref__' '__weakref__'
) )
_sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping _sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping
_active_sessions_ = defaultdict(dict) # Maintains strong references to active sessions
# Maintains strong references to active sessions
_active_sessions_: dict[int, dict[int, 'VoiceSession']] = defaultdict(dict)
def __init__(self, bot: LionBot, guildid: int, userid: int, data=None): def __init__(self, bot: LionBot, guildid: int, userid: int, data=None):
self.bot = bot self.bot = bot
@@ -96,6 +99,10 @@ class VoiceSession:
self._tag = None self._tag = None
self._start_time = None self._start_time = None
# Member session lock
# Ensures state changes are atomic and serialised
self.lock = asyncio.Lock()
def cancel(self): def cancel(self):
if self.start_task is not None: if self.start_task is not None:
self.start_task.cancel() self.start_task.cancel()
@@ -166,11 +173,12 @@ class VoiceSession:
return self return self
async def set_tag(self, new_tag): async def set_tag(self, new_tag):
if self.activity is SessionState.INACTIVE: async with self.lock:
raise ValueError("Cannot set tag on an inactive voice session.") if self.activity is SessionState.INACTIVE:
self._tag = new_tag raise ValueError("Cannot set tag on an inactive voice session.")
if self.data is not None: self._tag = new_tag
await self.data.update(tag=new_tag) if self.data is not None:
await self.data.update(tag=new_tag)
async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate): async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate):
""" """
@@ -194,33 +202,34 @@ class VoiceSession:
""" """
await asyncio.sleep(delay) await asyncio.sleep(delay)
logger.debug( async with self.lock:
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> " logger.info(
f"and channel <cid:{self.state.channelid}>." f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
) f"and channel <cid:{self.state.channelid}>."
# Create the lion if required )
await self.bot.core.lions.fetch_member(self.guildid, self.userid) # Create the lion if required
await self.bot.core.lions.fetch_member(self.guildid, self.userid)
# Create the tracked channel if required # Create the tracked channel if required
await self.registry.TrackedChannel.fetch_or_create( await self.registry.TrackedChannel.fetch_or_create(
self.state.channelid, guildid=self.guildid, deleted=False self.state.channelid, guildid=self.guildid, deleted=False
) )
# Insert an ongoing_session with the correct state, set data # Insert an ongoing_session with the correct state, set data
state = self.state state = self.state
self.data = await self.registry.VoiceSessionsOngoing.create( self.data = await self.registry.VoiceSessionsOngoing.create(
guildid=self.guildid, guildid=self.guildid,
userid=self.userid, userid=self.userid,
channelid=state.channelid, channelid=state.channelid,
start_time=start_time, start_time=start_time,
last_update=start_time, last_update=start_time,
live_stream=state.stream, live_stream=state.stream,
live_video=state.video, live_video=state.video,
hourly_coins=self.hourly_rate, hourly_coins=self.hourly_rate,
tag=self._tag tag=self._tag
) )
self.bot.dispatch('voice_session_start', self.data) self.bot.dispatch('voice_session_start', self.data)
self.start_task = None self.start_task = None
def schedule_expiry(self, expire_time): def schedule_expiry(self, expire_time):
""" """
@@ -275,33 +284,36 @@ class VoiceSession:
""" """
Close the session, or cancel the pending session. Idempotent. Close the session, or cancel the pending session. Idempotent.
""" """
if self.activity is SessionState.ONGOING: async with self.lock:
# End the ongoing session if self.activity is SessionState.ONGOING:
now = utc_now() # End the ongoing session
await self.data.close_study_session_at(self.guildid, self.userid, now) now = utc_now()
await self.data.close_study_session_at(self.guildid, self.userid, now)
# TODO: Something a bit saner/safer.. dispatch the finished session instead? # TODO: Something a bit saner/safer.. dispatch the finished session instead?
self.bot.dispatch('voice_session_end', self.data, now) self.bot.dispatch('voice_session_end', self.data, now)
# Rank update # Rank update
# TODO: Change to broadcasted event? # TODO: Change to broadcasted event?
rank_cog = self.bot.get_cog('RankCog') rank_cog = self.bot.get_cog('RankCog')
if rank_cog is not None: if rank_cog is not None:
asyncio.create_task(rank_cog.on_voice_session_complete( asyncio.create_task(rank_cog.on_voice_session_complete(
(self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0) (self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0)
)) ))
if self.start_task is not None: if self.start_task is not None:
self.start_task.cancel() self.start_task.cancel()
self.start_task = None self.start_task = None
if self.expiry_task is not None: if self.expiry_task is not None:
self.expiry_task.cancel() self.expiry_task.cancel()
self.expiry_task = None self.expiry_task = None
self.data = None self.data = None
self.state = None self.state = None
self.hourly_rate = None self.hourly_rate = None
self._tag = None
self._start_time = None
# Always release strong reference to session (to allow garbage collection) # Always release strong reference to session (to allow garbage collection)
self._active_sessions_[self.guildid].pop(self.userid) self._active_sessions_[self.guildid].pop(self.userid)