fix(voice): Avoid possible race condition.

This commit is contained in:
2023-10-08 01:57:54 +03:00
parent 85bbe527be
commit 5ac01d5cb2

View File

@@ -73,11 +73,14 @@ class VoiceSession:
'start_task', 'expiry_task',
'data', 'state', 'hourly_rate',
'_tag', '_start_time',
'lock',
'__weakref__'
)
_sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping
_active_sessions_ = defaultdict(dict) # Maintains strong references to active sessions
# Maintains strong references to active sessions
_active_sessions_: dict[int, dict[int, 'VoiceSession']] = defaultdict(dict)
def __init__(self, bot: LionBot, guildid: int, userid: int, data=None):
self.bot = bot
@@ -96,6 +99,10 @@ class VoiceSession:
self._tag = None
self._start_time = None
# Member session lock
# Ensures state changes are atomic and serialised
self.lock = asyncio.Lock()
def cancel(self):
if self.start_task is not None:
self.start_task.cancel()
@@ -166,11 +173,12 @@ class VoiceSession:
return self
async def set_tag(self, new_tag):
if self.activity is SessionState.INACTIVE:
raise ValueError("Cannot set tag on an inactive voice session.")
self._tag = new_tag
if self.data is not None:
await self.data.update(tag=new_tag)
async with self.lock:
if self.activity is SessionState.INACTIVE:
raise ValueError("Cannot set tag on an inactive voice session.")
self._tag = new_tag
if self.data is not None:
await self.data.update(tag=new_tag)
async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate):
"""
@@ -194,33 +202,34 @@ class VoiceSession:
"""
await asyncio.sleep(delay)
logger.debug(
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
f"and channel <cid:{self.state.channelid}>."
)
# Create the lion if required
await self.bot.core.lions.fetch_member(self.guildid, self.userid)
async with self.lock:
logger.info(
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
f"and channel <cid:{self.state.channelid}>."
)
# Create the lion if required
await self.bot.core.lions.fetch_member(self.guildid, self.userid)
# Create the tracked channel if required
await self.registry.TrackedChannel.fetch_or_create(
self.state.channelid, guildid=self.guildid, deleted=False
)
# Create the tracked channel if required
await self.registry.TrackedChannel.fetch_or_create(
self.state.channelid, guildid=self.guildid, deleted=False
)
# Insert an ongoing_session with the correct state, set data
state = self.state
self.data = await self.registry.VoiceSessionsOngoing.create(
guildid=self.guildid,
userid=self.userid,
channelid=state.channelid,
start_time=start_time,
last_update=start_time,
live_stream=state.stream,
live_video=state.video,
hourly_coins=self.hourly_rate,
tag=self._tag
)
self.bot.dispatch('voice_session_start', self.data)
self.start_task = None
# Insert an ongoing_session with the correct state, set data
state = self.state
self.data = await self.registry.VoiceSessionsOngoing.create(
guildid=self.guildid,
userid=self.userid,
channelid=state.channelid,
start_time=start_time,
last_update=start_time,
live_stream=state.stream,
live_video=state.video,
hourly_coins=self.hourly_rate,
tag=self._tag
)
self.bot.dispatch('voice_session_start', self.data)
self.start_task = None
def schedule_expiry(self, expire_time):
"""
@@ -275,33 +284,36 @@ class VoiceSession:
"""
Close the session, or cancel the pending session. Idempotent.
"""
if self.activity is SessionState.ONGOING:
# End the ongoing session
now = utc_now()
await self.data.close_study_session_at(self.guildid, self.userid, now)
async with self.lock:
if self.activity is SessionState.ONGOING:
# End the ongoing session
now = utc_now()
await self.data.close_study_session_at(self.guildid, self.userid, now)
# TODO: Something a bit saner/safer.. dispatch the finished session instead?
self.bot.dispatch('voice_session_end', self.data, now)
# TODO: Something a bit saner/safer.. dispatch the finished session instead?
self.bot.dispatch('voice_session_end', self.data, now)
# Rank update
# TODO: Change to broadcasted event?
rank_cog = self.bot.get_cog('RankCog')
if rank_cog is not None:
asyncio.create_task(rank_cog.on_voice_session_complete(
(self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0)
))
# Rank update
# TODO: Change to broadcasted event?
rank_cog = self.bot.get_cog('RankCog')
if rank_cog is not None:
asyncio.create_task(rank_cog.on_voice_session_complete(
(self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0)
))
if self.start_task is not None:
self.start_task.cancel()
self.start_task = None
if self.start_task is not None:
self.start_task.cancel()
self.start_task = None
if self.expiry_task is not None:
self.expiry_task.cancel()
self.expiry_task = None
if self.expiry_task is not None:
self.expiry_task.cancel()
self.expiry_task = None
self.data = None
self.state = None
self.hourly_rate = None
self.data = None
self.state = None
self.hourly_rate = None
self._tag = None
self._start_time = None
# Always release strong reference to session (to allow garbage collection)
self._active_sessions_[self.guildid].pop(self.userid)
# Always release strong reference to session (to allow garbage collection)
self._active_sessions_[self.guildid].pop(self.userid)