fix(voice): Avoid possible race condition.
This commit is contained in:
@@ -73,11 +73,14 @@ class VoiceSession:
|
|||||||
'start_task', 'expiry_task',
|
'start_task', 'expiry_task',
|
||||||
'data', 'state', 'hourly_rate',
|
'data', 'state', 'hourly_rate',
|
||||||
'_tag', '_start_time',
|
'_tag', '_start_time',
|
||||||
|
'lock',
|
||||||
'__weakref__'
|
'__weakref__'
|
||||||
)
|
)
|
||||||
|
|
||||||
_sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping
|
_sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping
|
||||||
_active_sessions_ = defaultdict(dict) # Maintains strong references to active sessions
|
|
||||||
|
# Maintains strong references to active sessions
|
||||||
|
_active_sessions_: dict[int, dict[int, 'VoiceSession']] = defaultdict(dict)
|
||||||
|
|
||||||
def __init__(self, bot: LionBot, guildid: int, userid: int, data=None):
|
def __init__(self, bot: LionBot, guildid: int, userid: int, data=None):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
@@ -96,6 +99,10 @@ class VoiceSession:
|
|||||||
self._tag = None
|
self._tag = None
|
||||||
self._start_time = None
|
self._start_time = None
|
||||||
|
|
||||||
|
# Member session lock
|
||||||
|
# Ensures state changes are atomic and serialised
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
def cancel(self):
|
def cancel(self):
|
||||||
if self.start_task is not None:
|
if self.start_task is not None:
|
||||||
self.start_task.cancel()
|
self.start_task.cancel()
|
||||||
@@ -166,11 +173,12 @@ class VoiceSession:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
async def set_tag(self, new_tag):
|
async def set_tag(self, new_tag):
|
||||||
if self.activity is SessionState.INACTIVE:
|
async with self.lock:
|
||||||
raise ValueError("Cannot set tag on an inactive voice session.")
|
if self.activity is SessionState.INACTIVE:
|
||||||
self._tag = new_tag
|
raise ValueError("Cannot set tag on an inactive voice session.")
|
||||||
if self.data is not None:
|
self._tag = new_tag
|
||||||
await self.data.update(tag=new_tag)
|
if self.data is not None:
|
||||||
|
await self.data.update(tag=new_tag)
|
||||||
|
|
||||||
async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate):
|
async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate):
|
||||||
"""
|
"""
|
||||||
@@ -194,33 +202,34 @@ class VoiceSession:
|
|||||||
"""
|
"""
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
logger.debug(
|
async with self.lock:
|
||||||
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
|
logger.info(
|
||||||
f"and channel <cid:{self.state.channelid}>."
|
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
|
||||||
)
|
f"and channel <cid:{self.state.channelid}>."
|
||||||
# Create the lion if required
|
)
|
||||||
await self.bot.core.lions.fetch_member(self.guildid, self.userid)
|
# Create the lion if required
|
||||||
|
await self.bot.core.lions.fetch_member(self.guildid, self.userid)
|
||||||
|
|
||||||
# Create the tracked channel if required
|
# Create the tracked channel if required
|
||||||
await self.registry.TrackedChannel.fetch_or_create(
|
await self.registry.TrackedChannel.fetch_or_create(
|
||||||
self.state.channelid, guildid=self.guildid, deleted=False
|
self.state.channelid, guildid=self.guildid, deleted=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Insert an ongoing_session with the correct state, set data
|
# Insert an ongoing_session with the correct state, set data
|
||||||
state = self.state
|
state = self.state
|
||||||
self.data = await self.registry.VoiceSessionsOngoing.create(
|
self.data = await self.registry.VoiceSessionsOngoing.create(
|
||||||
guildid=self.guildid,
|
guildid=self.guildid,
|
||||||
userid=self.userid,
|
userid=self.userid,
|
||||||
channelid=state.channelid,
|
channelid=state.channelid,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
last_update=start_time,
|
last_update=start_time,
|
||||||
live_stream=state.stream,
|
live_stream=state.stream,
|
||||||
live_video=state.video,
|
live_video=state.video,
|
||||||
hourly_coins=self.hourly_rate,
|
hourly_coins=self.hourly_rate,
|
||||||
tag=self._tag
|
tag=self._tag
|
||||||
)
|
)
|
||||||
self.bot.dispatch('voice_session_start', self.data)
|
self.bot.dispatch('voice_session_start', self.data)
|
||||||
self.start_task = None
|
self.start_task = None
|
||||||
|
|
||||||
def schedule_expiry(self, expire_time):
|
def schedule_expiry(self, expire_time):
|
||||||
"""
|
"""
|
||||||
@@ -275,33 +284,36 @@ class VoiceSession:
|
|||||||
"""
|
"""
|
||||||
Close the session, or cancel the pending session. Idempotent.
|
Close the session, or cancel the pending session. Idempotent.
|
||||||
"""
|
"""
|
||||||
if self.activity is SessionState.ONGOING:
|
async with self.lock:
|
||||||
# End the ongoing session
|
if self.activity is SessionState.ONGOING:
|
||||||
now = utc_now()
|
# End the ongoing session
|
||||||
await self.data.close_study_session_at(self.guildid, self.userid, now)
|
now = utc_now()
|
||||||
|
await self.data.close_study_session_at(self.guildid, self.userid, now)
|
||||||
|
|
||||||
# TODO: Something a bit saner/safer.. dispatch the finished session instead?
|
# TODO: Something a bit saner/safer.. dispatch the finished session instead?
|
||||||
self.bot.dispatch('voice_session_end', self.data, now)
|
self.bot.dispatch('voice_session_end', self.data, now)
|
||||||
|
|
||||||
# Rank update
|
# Rank update
|
||||||
# TODO: Change to broadcasted event?
|
# TODO: Change to broadcasted event?
|
||||||
rank_cog = self.bot.get_cog('RankCog')
|
rank_cog = self.bot.get_cog('RankCog')
|
||||||
if rank_cog is not None:
|
if rank_cog is not None:
|
||||||
asyncio.create_task(rank_cog.on_voice_session_complete(
|
asyncio.create_task(rank_cog.on_voice_session_complete(
|
||||||
(self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0)
|
(self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0)
|
||||||
))
|
))
|
||||||
|
|
||||||
if self.start_task is not None:
|
if self.start_task is not None:
|
||||||
self.start_task.cancel()
|
self.start_task.cancel()
|
||||||
self.start_task = None
|
self.start_task = None
|
||||||
|
|
||||||
if self.expiry_task is not None:
|
if self.expiry_task is not None:
|
||||||
self.expiry_task.cancel()
|
self.expiry_task.cancel()
|
||||||
self.expiry_task = None
|
self.expiry_task = None
|
||||||
|
|
||||||
self.data = None
|
self.data = None
|
||||||
self.state = None
|
self.state = None
|
||||||
self.hourly_rate = None
|
self.hourly_rate = None
|
||||||
|
self._tag = None
|
||||||
|
self._start_time = None
|
||||||
|
|
||||||
# Always release strong reference to session (to allow garbage collection)
|
# Always release strong reference to session (to allow garbage collection)
|
||||||
self._active_sessions_[self.guildid].pop(self.userid)
|
self._active_sessions_[self.guildid].pop(self.userid)
|
||||||
|
|||||||
Reference in New Issue
Block a user