fix(voice): Avoid possible race condition.
This commit is contained in:
@@ -73,11 +73,14 @@ class VoiceSession:
|
|||||||
'start_task', 'expiry_task',
|
'start_task', 'expiry_task',
|
||||||
'data', 'state', 'hourly_rate',
|
'data', 'state', 'hourly_rate',
|
||||||
'_tag', '_start_time',
|
'_tag', '_start_time',
|
||||||
|
'lock',
|
||||||
'__weakref__'
|
'__weakref__'
|
||||||
)
|
)
|
||||||
|
|
||||||
_sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping
|
_sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping
|
||||||
_active_sessions_ = defaultdict(dict) # Maintains strong references to active sessions
|
|
||||||
|
# Maintains strong references to active sessions
|
||||||
|
_active_sessions_: dict[int, dict[int, 'VoiceSession']] = defaultdict(dict)
|
||||||
|
|
||||||
def __init__(self, bot: LionBot, guildid: int, userid: int, data=None):
|
def __init__(self, bot: LionBot, guildid: int, userid: int, data=None):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
@@ -96,6 +99,10 @@ class VoiceSession:
|
|||||||
self._tag = None
|
self._tag = None
|
||||||
self._start_time = None
|
self._start_time = None
|
||||||
|
|
||||||
|
# Member session lock
|
||||||
|
# Ensures state changes are atomic and serialised
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
def cancel(self):
|
def cancel(self):
|
||||||
if self.start_task is not None:
|
if self.start_task is not None:
|
||||||
self.start_task.cancel()
|
self.start_task.cancel()
|
||||||
@@ -166,6 +173,7 @@ class VoiceSession:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
async def set_tag(self, new_tag):
|
async def set_tag(self, new_tag):
|
||||||
|
async with self.lock:
|
||||||
if self.activity is SessionState.INACTIVE:
|
if self.activity is SessionState.INACTIVE:
|
||||||
raise ValueError("Cannot set tag on an inactive voice session.")
|
raise ValueError("Cannot set tag on an inactive voice session.")
|
||||||
self._tag = new_tag
|
self._tag = new_tag
|
||||||
@@ -194,7 +202,8 @@ class VoiceSession:
|
|||||||
"""
|
"""
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
logger.debug(
|
async with self.lock:
|
||||||
|
logger.info(
|
||||||
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
|
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
|
||||||
f"and channel <cid:{self.state.channelid}>."
|
f"and channel <cid:{self.state.channelid}>."
|
||||||
)
|
)
|
||||||
@@ -275,6 +284,7 @@ class VoiceSession:
|
|||||||
"""
|
"""
|
||||||
Close the session, or cancel the pending session. Idempotent.
|
Close the session, or cancel the pending session. Idempotent.
|
||||||
"""
|
"""
|
||||||
|
async with self.lock:
|
||||||
if self.activity is SessionState.ONGOING:
|
if self.activity is SessionState.ONGOING:
|
||||||
# End the ongoing session
|
# End the ongoing session
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
@@ -302,6 +312,8 @@ class VoiceSession:
|
|||||||
self.data = None
|
self.data = None
|
||||||
self.state = None
|
self.state = None
|
||||||
self.hourly_rate = None
|
self.hourly_rate = None
|
||||||
|
self._tag = None
|
||||||
|
self._start_time = None
|
||||||
|
|
||||||
# Always release strong reference to session (to allow garbage collection)
|
# Always release strong reference to session (to allow garbage collection)
|
||||||
self._active_sessions_[self.guildid].pop(self.userid)
|
self._active_sessions_[self.guildid].pop(self.userid)
|
||||||
|
|||||||
Reference in New Issue
Block a user