fix(voice): Avoid possible race condition.
This commit is contained in:
@@ -73,11 +73,14 @@ class VoiceSession:
|
||||
'start_task', 'expiry_task',
|
||||
'data', 'state', 'hourly_rate',
|
||||
'_tag', '_start_time',
|
||||
'lock',
|
||||
'__weakref__'
|
||||
)
|
||||
|
||||
_sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping
|
||||
_active_sessions_ = defaultdict(dict) # Maintains strong references to active sessions
|
||||
|
||||
# Maintains strong references to active sessions
|
||||
_active_sessions_: dict[int, dict[int, 'VoiceSession']] = defaultdict(dict)
|
||||
|
||||
def __init__(self, bot: LionBot, guildid: int, userid: int, data=None):
|
||||
self.bot = bot
|
||||
@@ -96,6 +99,10 @@ class VoiceSession:
|
||||
self._tag = None
|
||||
self._start_time = None
|
||||
|
||||
# Member session lock
|
||||
# Ensures state changes are atomic and serialised
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
def cancel(self):
|
||||
if self.start_task is not None:
|
||||
self.start_task.cancel()
|
||||
@@ -166,11 +173,12 @@ class VoiceSession:
|
||||
return self
|
||||
|
||||
async def set_tag(self, new_tag):
|
||||
if self.activity is SessionState.INACTIVE:
|
||||
raise ValueError("Cannot set tag on an inactive voice session.")
|
||||
self._tag = new_tag
|
||||
if self.data is not None:
|
||||
await self.data.update(tag=new_tag)
|
||||
async with self.lock:
|
||||
if self.activity is SessionState.INACTIVE:
|
||||
raise ValueError("Cannot set tag on an inactive voice session.")
|
||||
self._tag = new_tag
|
||||
if self.data is not None:
|
||||
await self.data.update(tag=new_tag)
|
||||
|
||||
async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate):
|
||||
"""
|
||||
@@ -194,33 +202,34 @@ class VoiceSession:
|
||||
"""
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
logger.debug(
|
||||
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
|
||||
f"and channel <cid:{self.state.channelid}>."
|
||||
)
|
||||
# Create the lion if required
|
||||
await self.bot.core.lions.fetch_member(self.guildid, self.userid)
|
||||
async with self.lock:
|
||||
logger.info(
|
||||
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
|
||||
f"and channel <cid:{self.state.channelid}>."
|
||||
)
|
||||
# Create the lion if required
|
||||
await self.bot.core.lions.fetch_member(self.guildid, self.userid)
|
||||
|
||||
# Create the tracked channel if required
|
||||
await self.registry.TrackedChannel.fetch_or_create(
|
||||
self.state.channelid, guildid=self.guildid, deleted=False
|
||||
)
|
||||
# Create the tracked channel if required
|
||||
await self.registry.TrackedChannel.fetch_or_create(
|
||||
self.state.channelid, guildid=self.guildid, deleted=False
|
||||
)
|
||||
|
||||
# Insert an ongoing_session with the correct state, set data
|
||||
state = self.state
|
||||
self.data = await self.registry.VoiceSessionsOngoing.create(
|
||||
guildid=self.guildid,
|
||||
userid=self.userid,
|
||||
channelid=state.channelid,
|
||||
start_time=start_time,
|
||||
last_update=start_time,
|
||||
live_stream=state.stream,
|
||||
live_video=state.video,
|
||||
hourly_coins=self.hourly_rate,
|
||||
tag=self._tag
|
||||
)
|
||||
self.bot.dispatch('voice_session_start', self.data)
|
||||
self.start_task = None
|
||||
# Insert an ongoing_session with the correct state, set data
|
||||
state = self.state
|
||||
self.data = await self.registry.VoiceSessionsOngoing.create(
|
||||
guildid=self.guildid,
|
||||
userid=self.userid,
|
||||
channelid=state.channelid,
|
||||
start_time=start_time,
|
||||
last_update=start_time,
|
||||
live_stream=state.stream,
|
||||
live_video=state.video,
|
||||
hourly_coins=self.hourly_rate,
|
||||
tag=self._tag
|
||||
)
|
||||
self.bot.dispatch('voice_session_start', self.data)
|
||||
self.start_task = None
|
||||
|
||||
def schedule_expiry(self, expire_time):
|
||||
"""
|
||||
@@ -275,33 +284,36 @@ class VoiceSession:
|
||||
"""
|
||||
Close the session, or cancel the pending session. Idempotent.
|
||||
"""
|
||||
if self.activity is SessionState.ONGOING:
|
||||
# End the ongoing session
|
||||
now = utc_now()
|
||||
await self.data.close_study_session_at(self.guildid, self.userid, now)
|
||||
async with self.lock:
|
||||
if self.activity is SessionState.ONGOING:
|
||||
# End the ongoing session
|
||||
now = utc_now()
|
||||
await self.data.close_study_session_at(self.guildid, self.userid, now)
|
||||
|
||||
# TODO: Something a bit saner/safer.. dispatch the finished session instead?
|
||||
self.bot.dispatch('voice_session_end', self.data, now)
|
||||
# TODO: Something a bit saner/safer.. dispatch the finished session instead?
|
||||
self.bot.dispatch('voice_session_end', self.data, now)
|
||||
|
||||
# Rank update
|
||||
# TODO: Change to broadcasted event?
|
||||
rank_cog = self.bot.get_cog('RankCog')
|
||||
if rank_cog is not None:
|
||||
asyncio.create_task(rank_cog.on_voice_session_complete(
|
||||
(self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0)
|
||||
))
|
||||
# Rank update
|
||||
# TODO: Change to broadcasted event?
|
||||
rank_cog = self.bot.get_cog('RankCog')
|
||||
if rank_cog is not None:
|
||||
asyncio.create_task(rank_cog.on_voice_session_complete(
|
||||
(self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0)
|
||||
))
|
||||
|
||||
if self.start_task is not None:
|
||||
self.start_task.cancel()
|
||||
self.start_task = None
|
||||
if self.start_task is not None:
|
||||
self.start_task.cancel()
|
||||
self.start_task = None
|
||||
|
||||
if self.expiry_task is not None:
|
||||
self.expiry_task.cancel()
|
||||
self.expiry_task = None
|
||||
if self.expiry_task is not None:
|
||||
self.expiry_task.cancel()
|
||||
self.expiry_task = None
|
||||
|
||||
self.data = None
|
||||
self.state = None
|
||||
self.hourly_rate = None
|
||||
self.data = None
|
||||
self.state = None
|
||||
self.hourly_rate = None
|
||||
self._tag = None
|
||||
self._start_time = None
|
||||
|
||||
# Always release strong reference to session (to allow garbage collection)
|
||||
self._active_sessions_[self.guildid].pop(self.userid)
|
||||
# Always release strong reference to session (to allow garbage collection)
|
||||
self._active_sessions_[self.guildid].pop(self.userid)
|
||||
|
||||
Reference in New Issue
Block a user