From b2aa651eaa38c5b2c0b79590d2e9b68e6b74d457 Mon Sep 17 00:00:00 2001 From: Conatum Date: Fri, 24 Sep 2021 21:12:12 +0300 Subject: [PATCH] fix (rooms): Harden against race conditions. Add locking to room init, turnover, and cancellation. Add cleanup of nonexistent members in slot init. Fix an issue where members were being charged for cancelling rooms. --- bot/modules/accountability/TimeSlot.py | 7 +++- bot/modules/accountability/commands.py | 46 ++++++++++++++------------ bot/modules/accountability/tracker.py | 21 ++++++++++-- 3 files changed, 48 insertions(+), 26 deletions(-) diff --git a/bot/modules/accountability/TimeSlot.py b/bot/modules/accountability/TimeSlot.py index e2c7b44b..fb0f4551 100644 --- a/bot/modules/accountability/TimeSlot.py +++ b/bot/modules/accountability/TimeSlot.py @@ -37,7 +37,7 @@ class SlotMember: @property def member(self): - return self.guild.get_member(self.data.userid) + return self.guild.get_member(self.userid) @property def has_attended(self): @@ -254,6 +254,11 @@ class TimeSlot: Adds the TimeSlot to cache. Returns the (channelid, messageid). """ + # Cleanup any non-existent members + for memid, mem in list(self.members.items()): + if not mem.data or not mem.member: + self.members.pop(memid) + # Calculate overwrites overwrites = { mem.member: self._member_overwrite diff --git a/bot/modules/accountability/commands.py b/bot/modules/accountability/commands.py index dd9bc796..551ddc15 100644 --- a/bot/modules/accountability/commands.py +++ b/bot/modules/accountability/commands.py @@ -11,8 +11,9 @@ from data.conditions import GEQ from .module import module from .lib import utc_now from .tracker import AccountabilityGuild as AGuild +from .tracker import room_lock from .TimeSlot import SlotMember -from .data import accountability_members, accountability_member_info, accountability_open_slots, accountability_rooms +from .data import accountability_members, accountability_member_info, accountability_rooms @module.cmd( @@ -104,29 +105,30 @@ async def cmd_rooms(ctx): cost = len(to_cancel) * ctx.guild_settings.accountability_price.value slotids = [row['slotid'] for row in to_cancel] - accountability_members.delete_where( - userid=ctx.author.id, - slotid=slotids - ) + async with room_lock: + accountability_members.delete_where( + userid=ctx.author.id, + slotid=slotids + ) - # Handle case where the slot has already opened - for row in to_cancel: - aguild = AGuild.cache.get(row['guildid'], None) - if aguild: - if aguild.upcoming_slot and aguild.upcoming_slot.data and (aguild.upcoming_slot.data.slotid in slotids): - aguild.upcoming_slot.members.pop(ctx.author.id, None) - if aguild.upcoming_slot.channel: - try: - await aguild.upcoming_slot.channel.set_permissions( - ctx.author, - overwrite=None - ) - except discord.HTTPException: - pass - await aguild.upcoming_slot.update_status() - break + # Handle case where the slot has already opened + for row in to_cancel: + aguild = AGuild.cache.get(row['guildid'], None) + if aguild and aguild.upcoming_slot and aguild.upcoming_slot.data: + if aguild.upcoming_slot.data.slotid in slotids: + aguild.upcoming_slot.members.pop(ctx.author.id, None) + if aguild.upcoming_slot.channel: + try: + await aguild.upcoming_slot.channel.set_permissions( + ctx.author, + overwrite=None + ) + except discord.HTTPException: + pass + await aguild.upcoming_slot.update_status() + break - ctx.alion.addCoins(-cost) + ctx.alion.addCoins(cost) await ctx.embed_reply( "Successfully canceled your bookings." ) diff --git a/bot/modules/accountability/tracker.py b/bot/modules/accountability/tracker.py index 6a5c2799..19ccb663 100644 --- a/bot/modules/accountability/tracker.py +++ b/bot/modules/accountability/tracker.py @@ -19,6 +19,19 @@ from .module import module voice_ignore_lock = asyncio.Lock() +room_lock = asyncio.Lock() + + +def locker(lock): + """ + Function decorator to wrap the function in a provided Lock + """ + def decorator(func): + async def wrapped(*args, **kwargs): + async with lock: + return await func(*args, **kwargs) + return wrapped + return decorator class AccountabilityGuild: @@ -184,7 +197,7 @@ async def turnover(): ) for slot in current_slots for mem in slot.members.values() - if mem.member.voice and mem.member.voice.channel != slot.channel + if mem.data and mem.member and mem.member.voice and mem.member.voice.channel != slot.channel ) # We return exceptions here to ignore any permission issues that occur with moving members. # It's also possible (likely) that members will move while we are moving other members @@ -279,7 +292,8 @@ async def _accountability_loop(): next_time = next_time + datetime.timedelta(minutes=5) # Open next sessions try: - await open_next(next_time) + async with room_lock: + await open_next(next_time) except Exception: # Unknown exception. Catch it so the loop doesn't die. client.log( @@ -293,7 +307,8 @@ async def _accountability_loop(): elif next_time.minute == 0: # Start new sessions try: - await turnover() + async with room_lock: + await turnover() except Exception: # Unknown exception. Catch it so the loop doesn't die. client.log(