diff --git a/bot/modules/accountability/commands.py b/bot/modules/accountability/commands.py index c86dae3b..d6b2e069 100644 --- a/bot/modules/accountability/commands.py +++ b/bot/modules/accountability/commands.py @@ -2,6 +2,7 @@ import re import datetime import discord import asyncio +import contextlib from cmdClient.checks import in_guild from meta import client @@ -38,6 +39,26 @@ def time_format(time): time.timestamp() + 3600, ) +user_locks = {} # Map userid -> ctx + + +@contextlib.contextmanager +def ensure_exclusive(ctx): + """ + Cancel any existing exclusive contexts for the author. + """ + old_ctx = user_locks.pop(ctx.author.id, None) + if old_ctx: + [task.cancel() for task in old_ctx.tasks] + + user_locks[ctx.author.id] = ctx + try: + yield + finally: + new_ctx = user_locks.get(ctx.author.id, None) + if new_ctx and new_ctx.msg.id == ctx.msg.id: + user_locks.pop(ctx.author.id) + @module.cmd( name="rooms", @@ -101,87 +122,88 @@ async def cmd_rooms(ctx): valid = valid and (re.search(multiselect_regex, msg.content) or msg.content.lower() == 'c') return valid - try: - message = await ctx.client.wait_for('message', check=check, timeout=60) - except asyncio.TimeoutError: + with ensure_exclusive(ctx): try: - await out_msg.edit( - content=None, - embed=discord.Embed( - description="Cancel menu timed out, no accountability sessions were cancelled.", - colour=discord.Colour.red() + message = await ctx.client.wait_for('message', check=check, timeout=60) + except asyncio.TimeoutError: + try: + await out_msg.edit( + content=None, + embed=discord.Embed( + description="Cancel menu timed out, no accountability sessions were cancelled.", + colour=discord.Colour.red() + ) ) - ) - await out_msg.clear_reactions() + await out_msg.clear_reactions() + except discord.HTTPException: + pass + return + + try: + await out_msg.delete() + await message.delete() except discord.HTTPException: pass - return - try: - await out_msg.delete() - await message.delete() - except discord.HTTPException: - pass + if message.content.lower() == 'c': + return - if message.content.lower() == 'c': - return + to_cancel = [ + joined_rows[index] + for index in parse_ranges(message.content) if index < len(joined_rows) + ] + if not to_cancel: + return await ctx.error_reply("No valid bookings selected for cancellation.") + cost = len(to_cancel) * ctx.guild_settings.accountability_price.value - to_cancel = [ - joined_rows[index] - for index in parse_ranges(message.content) if index < len(joined_rows) - ] - if not to_cancel: - return await ctx.error_reply("No valid bookings selected for cancellation.") - cost = len(to_cancel) * ctx.guild_settings.accountability_price.value - - slotids = [row['slotid'] for row in to_cancel] - async with room_lock: - # TODO: Use the return from this to calculate the cost! - accountability_members.delete_where( - userid=ctx.author.id, - slotid=slotids - ) - - # Handle case where the slot has already opened - # TODO: Possible race condition if they open over the hour border? Might never cancel - for row in to_cancel: - aguild = AGuild.cache.get(row['guildid'], None) - if aguild and aguild.upcoming_slot and aguild.upcoming_slot.data: - if aguild.upcoming_slot.data.slotid in slotids: - aguild.upcoming_slot.members.pop(ctx.author.id, None) - if aguild.upcoming_slot.channel: - try: - await aguild.upcoming_slot.channel.set_permissions( - ctx.author, - overwrite=None - ) - except discord.HTTPException: - pass - await aguild.upcoming_slot.update_status() - break - - ctx.alion.addCoins(cost) - - remaining = [row for row in joined_rows if row['slotid'] not in slotids] - if not remaining: - await ctx.embed_reply("Cancelled all your upcoming accountability sessions!") - else: - next_booked_time = min(row['start_at'] for row in remaining) - if len(to_cancel) > 1: - await ctx.embed_reply( - "Cancelled `{}` upcoming sessions!\nYour next session is at .".format( - len(to_cancel), - next_booked_time.timestamp() - ) + slotids = [row['slotid'] for row in to_cancel] + async with room_lock: + # TODO: Use the return from this to calculate the cost! + accountability_members.delete_where( + userid=ctx.author.id, + slotid=slotids ) + + # Handle case where the slot has already opened + # TODO: Possible race condition if they open over the hour border? Might never cancel + for row in to_cancel: + aguild = AGuild.cache.get(row['guildid'], None) + if aguild and aguild.upcoming_slot and aguild.upcoming_slot.data: + if aguild.upcoming_slot.data.slotid in slotids: + aguild.upcoming_slot.members.pop(ctx.author.id, None) + if aguild.upcoming_slot.channel: + try: + await aguild.upcoming_slot.channel.set_permissions( + ctx.author, + overwrite=None + ) + except discord.HTTPException: + pass + await aguild.upcoming_slot.update_status() + break + + ctx.alion.addCoins(cost) + + remaining = [row for row in joined_rows if row['slotid'] not in slotids] + if not remaining: + await ctx.embed_reply("Cancelled all your upcoming accountability sessions!") else: - await ctx.embed_reply( - "Cancelled your session at !\n" - "Your next session is at .".format( - to_cancel[0]['start_at'].timestamp(), - next_booked_time.timestamp() + next_booked_time = min(row['start_at'] for row in remaining) + if len(to_cancel) > 1: + await ctx.embed_reply( + "Cancelled `{}` upcoming sessions!\nYour next session is at .".format( + len(to_cancel), + next_booked_time.timestamp() + ) + ) + else: + await ctx.embed_reply( + "Cancelled your session at !\n" + "Your next session is at .".format( + to_cancel[0]['start_at'].timestamp(), + next_booked_time.timestamp() + ) ) - ) elif command == 'book': # Show booking menu # Get attendee count @@ -239,86 +261,87 @@ async def cmd_rooms(ctx): valid = valid and (re.search(multiselect_regex, msg.content) or msg.content.lower() == 'c') return valid - try: - message = await ctx.client.wait_for('message', check=check, timeout=60) - except asyncio.TimeoutError: + with ensure_exclusive(ctx): try: - await out_msg.edit( - content=None, - embed=discord.Embed( - description="Booking menu timed out, no sessions were booked.", - colour=discord.Colour.red() + message = await ctx.client.wait_for('message', check=check, timeout=60) + except asyncio.TimeoutError: + try: + await out_msg.edit( + content=None, + embed=discord.Embed( + description="Booking menu timed out, no sessions were booked.", + colour=discord.Colour.red() + ) ) - ) - await out_msg.clear_reactions() + await out_msg.clear_reactions() + except discord.HTTPException: + pass + return + + try: + await out_msg.delete() + await message.delete() except discord.HTTPException: pass - return - try: - await out_msg.delete() - await message.delete() - except discord.HTTPException: - pass + if message.content.lower() == 'c': + return - if message.content.lower() == 'c': - return - - to_book = [ - times[index] - for index in parse_ranges(message.content) if index < len(times) - ] - if not to_book: - return await ctx.error_reply("No valid sessions selected.") - cost = len(to_book) * ctx.guild_settings.accountability_price.value - if cost > ctx.alion.coins: - return await ctx.error_reply( - "Sorry, booking `{}` sessions costs `{}` coins, and you only have `{}`!".format( - len(to_book), - cost, - ctx.alion.coins + to_book = [ + times[index] + for index in parse_ranges(message.content) if index < len(times) + ] + if not to_book: + return await ctx.error_reply("No valid sessions selected.") + cost = len(to_book) * ctx.guild_settings.accountability_price.value + if cost > ctx.alion.coins: + return await ctx.error_reply( + "Sorry, booking `{}` sessions costs `{}` coins, and you only have `{}`!".format( + len(to_book), + cost, + ctx.alion.coins + ) ) + + # Add the member to data, creating the row if required + slot_rows = accountability_rooms.fetch_rows_where( + guildid=ctx.guild.id, + start_at=to_book + ) + slotids = [row.slotid for row in slot_rows] + to_add = set(to_book).difference((row.start_at for row in slot_rows)) + if to_add: + slotids.extend(row['slotid'] for row in accountability_rooms.insert_many( + *((ctx.guild.id, start_at) for start_at in to_add), + insert_keys=('guildid', 'start_at'), + )) + accountability_members.insert_many( + *((slotid, ctx.author.id, ctx.guild_settings.accountability_price.value) for slotid in slotids), + insert_keys=('slotid', 'userid', 'paid') ) - # Add the member to data, creating the row if required - slot_rows = accountability_rooms.fetch_rows_where( - guildid=ctx.guild.id, - start_at=to_book - ) - slotids = [row.slotid for row in slot_rows] - to_add = set(to_book).difference((row.start_at for row in slot_rows)) - if to_add: - slotids.extend(row['slotid'] for row in accountability_rooms.insert_many( - *((ctx.guild.id, start_at) for start_at in to_add), - insert_keys=('guildid', 'start_at'), - )) - accountability_members.insert_many( - *((slotid, ctx.author.id, ctx.guild_settings.accountability_price.value) for slotid in slotids), - insert_keys=('slotid', 'userid', 'paid') - ) - - # Handle case where the slot has already opened - aguild = AGuild.cache.get(ctx.guild.id, None) - if aguild: - if aguild.upcoming_slot and aguild.upcoming_slot.start_time in to_book: - slot = aguild.upcoming_slot - if not slot.data: - # Handle slot activation - slot._refresh() - channelid, messageid = await slot.open() - accountability_rooms.update_where( - {'channelid': channelid, 'messageid': messageid}, - slotid=slot.data.slotid - ) - else: - slot.members[ctx.author.id] = SlotMember(slot.data.slotid, ctx.author.id, ctx.guild) - # Also update the channel permissions - try: - await slot.channel.set_permissions(ctx.author, view_channel=True, connect=True) - except discord.HTTPException: - pass - await slot.update_status() - ctx.alion.addCoins(-cost) + # Handle case where the slot has already opened + aguild = AGuild.cache.get(ctx.guild.id, None) + if aguild: + if aguild.upcoming_slot and aguild.upcoming_slot.start_time in to_book: + slot = aguild.upcoming_slot + if not slot.data: + # Handle slot activation + slot._refresh() + channelid, messageid = await slot.open() + accountability_rooms.update_where( + {'channelid': channelid, 'messageid': messageid}, + slotid=slot.data.slotid + ) + else: + slot.members[ctx.author.id] = SlotMember(slot.data.slotid, ctx.author.id, ctx.guild) + # Also update the channel permissions + try: + await slot.channel.set_permissions(ctx.author, view_channel=True, connect=True) + except discord.HTTPException: + pass + await slot.update_status() + ctx.alion.addCoins(-cost) # Ack purchase embed = discord.Embed(