diff --git a/src/modules/ranks/cog.py b/src/modules/ranks/cog.py index 16b4bfab..3cc61bad 100644 --- a/src/modules/ranks/cog.py +++ b/src/modules/ranks/cog.py @@ -1,6 +1,7 @@ from typing import Optional import asyncio import datetime +from weakref import WeakValueDictionary import discord from discord.ext import commands as cmds @@ -128,6 +129,9 @@ class RankCog(LionCog): # pop the guild whenever the season is updated or the rank type changes. self._member_ranks = {} + # Weakly referenced Locks for each guild to serialise rank actions + self._rank_locks: dict[int, asyncio.Lock] = WeakValueDictionary() + async def cog_load(self): await self.data.init() @@ -138,6 +142,13 @@ class RankCog(LionCog): configcog = self.bot.get_cog('ConfigCog') self.crossload_group(self.configure_group, configcog.configure_group) + def ranklock(self, guildid): + lock = self._rank_locks.get(guildid, None) + if lock is None: + lock = self._rank_locks[guildid] = asyncio.Lock() + logger.debug(f"Getting rank lock for guild (locked: {lock.locked()})") + return lock + # ---------- Event handlers ---------- # season_start setting event handler.. clears the guild season rank cache @LionCog.listener('on_guildset_season_start') @@ -257,26 +268,22 @@ class RankCog(LionCog): """ Handle batch of completed message sessions. """ - tasks = [] - # TODO: Thread safety - # TODO: Locking between refresh and individual updates for guildid, userid, messages, guild_xp in session_data: lguild = await self.bot.core.lions.fetch_guild(guildid) rank_type = lguild.config.get('rank_type').value if rank_type in (RankType.MESSAGE, RankType.XP): - if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members: - session_rank = _members[userid] - session_rank.stat += messages if (rank_type is RankType.MESSAGE) else guild_xp - else: - session_rank = await self.get_member_rank(guildid, userid) + async with self.ranklock(guildid): + if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members: + session_rank = _members[userid] + session_rank.stat += messages if (rank_type is RankType.MESSAGE) else guild_xp + else: + session_rank = await self.get_member_rank(guildid, userid) - if session_rank.next_rank is not None and session_rank.stat > session_rank.next_rank.required: - tasks.append(asyncio.create_task(self.update_rank(session_rank))) - else: - tasks.append(asyncio.create_task(self._role_check(session_rank))) - - if tasks: - await asyncio.gather(*tasks) + if session_rank.next_rank is not None and session_rank.stat > session_rank.next_rank.required: + task = asyncio.create_task(self.update_rank(session_rank), name='update-message-rank') + else: + task = asyncio.create_task(self._role_check(session_rank), name='rank-role-check') + await task async def _role_check(self, session_rank: SeasonRank): guild = self.bot.get_guild(session_rank.guildid) @@ -445,9 +452,6 @@ class RankCog(LionCog): @log_wrap(action="Voice Rank Hook") async def on_voice_session_complete(self, *session_data): - tasks = [] - # TODO: Thread safety - # TODO: Locking between refresh and individual updates for guildid, userid, duration, guild_xp in session_data: lguild = await self.bot.core.lions.fetch_guild(guildid) unranked_role_setting = await self.bot.get_cog('StatsCog').settings.UnrankedRoles.get(guildid) @@ -458,27 +462,28 @@ class RankCog(LionCog): continue rank_type = lguild.config.get('rank_type').value if rank_type in (RankType.VOICE,): - if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members: - session_rank = _members[userid] - # TODO: Temporary measure - season_start = lguild.config.get('season_start').value or datetime.datetime(1970, 1, 1) - stat_data = self.bot.get_cog('StatsCog').data - session_rank.stat = (await stat_data.VoiceSessionStats.study_times_since( - guildid, userid, season_start) - )[0] - # session_rank.stat += duration if (rank_type is RankType.VOICE) else guild_xp - else: - session_rank = await self.get_member_rank(guildid, userid) + async with self.ranklock(guildid): + if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members: + session_rank = _members[userid] + # TODO: Temporary measure + season_start = lguild.config.get('season_start').value or datetime.datetime(1970, 1, 1) + stat_data = self.bot.get_cog('StatsCog').data + session_rank.stat = (await stat_data.VoiceSessionStats.study_times_since( + guildid, userid, season_start) + )[0] + # session_rank.stat += duration if (rank_type is RankType.VOICE) else guild_xp + else: + session_rank = await self.get_member_rank(guildid, userid) - if session_rank.next_rank is not None and session_rank.stat > session_rank.next_rank.required: - tasks.append(asyncio.create_task(self.update_rank(session_rank))) - else: - tasks.append(asyncio.create_task(self._role_check(session_rank))) - if tasks: - await asyncio.gather(*tasks) + if session_rank.next_rank is not None and session_rank.stat > session_rank.next_rank.required: + task = asyncio.create_task(self.update_rank(session_rank), name='voice-rank-update') + else: + task = asyncio.create_task(self._role_check(session_rank), name='voice-role-check') async def on_xp_update(self, *xp_data): - ... + # Currently no-op since xp is given purely by message stats + # Implement if xp ever becomes a combination of message and voice stats + pass @log_wrap(action='interactive rank refresh') async def interactive_rank_refresh(self, interaction: discord.Interaction, guild: discord.Guild): @@ -487,9 +492,9 @@ class RankCog(LionCog): """ t = self.bot.translator.t if not interaction.response.is_done(): - await interaction.response.defer(thinking=True, ephemeral=False) + await interaction.response.defer(thinking=False) ui = RankRefreshUI(self.bot, guild, callerid=interaction.user.id, timeout=None) - await ui.run(interaction) + await ui.send(interaction.channel) # Retrieve fresh rank roles ranks = await self.get_guild_ranks(guild.id, refresh=True) @@ -655,18 +660,18 @@ class RankCog(LionCog): # Save correct member ranks and given roles to data # First clear the member rank data entirely await self.data.MemberRank.table.delete_where(guildid=guild.id) - column = self._get_rankid_column(rank_type) - values = [ - (guild.id, memberid, rank.rankid, rank.roleid) - for memberid, rank in true_member_ranks.items() - ] - await self.data.MemberRank.table.insert_many( - ('guildid', 'userid', column, 'last_roleid'), - *values - ) + if true_member_ranks: + column = self._get_rankid_column(rank_type) + values = [ + (guild.id, memberid, rank.rankid, rank.roleid) + for memberid, rank in true_member_ranks.items() + ] + await self.data.MemberRank.table.insert_many( + ('guildid', 'userid', column, 'last_roleid'), + *values + ) self.flush_guild_ranks(guild.id) await ui.set_done() - await ui.wait() # ---------- Commands ---------- @cmds.hybrid_command(name=_p('cmd:ranks', "ranks")) diff --git a/src/modules/ranks/ui/overview.py b/src/modules/ranks/ui/overview.py index 41f163c9..90741832 100644 --- a/src/modules/ranks/ui/overview.py +++ b/src/modules/ranks/ui/overview.py @@ -31,6 +31,7 @@ class RankOverviewUI(MessageUI): self.bot = bot self.guild = guild self.guildid = guild.id + self.cog = bot.get_cog('RankCog') self.lguild = None @@ -99,8 +100,8 @@ class RankOverviewUI(MessageUI): Refresh the current ranks, ensuring that all members have the correct rank. """ - cog = self.bot.get_cog('RankCog') - await cog.interactive_rank_refresh(press, self.guild) + async with self.cog.ranklock(self.guild.id): + await self.cog.interactive_rank_refresh(press, self.guild) async def refresh_button_refresh(self): self.refresh_button.label = self.bot.translator.t(_p( @@ -135,9 +136,10 @@ class RankOverviewUI(MessageUI): except ResponseTimedOut: result = False if result: - await self.rank_model.table.delete_where(guildid=self.guildid) - self.bot.get_cog('RankCog').flush_guild_ranks(self.guild.id) - self.ranks = [] + async with self.cog.ranklock(self.guild.id): + await self.rank_model.table.delete_where(guildid=self.guildid) + self.cog.flush_guild_ranks(self.guild.id) + self.ranks = [] await self.redraw() async def clear_button_refresh(self): diff --git a/src/modules/ranks/ui/refresh.py b/src/modules/ranks/ui/refresh.py index 9289c345..e01dc059 100644 --- a/src/modules/ranks/ui/refresh.py +++ b/src/modules/ranks/ui/refresh.py @@ -199,10 +199,11 @@ class RankRefreshUI(MessageUI): )) value = t(_p( 'ui:refresh_ranks|embed|field:remove|value', - "0 {progress} {total}" + "{progress} {done}/{total} removed" )).format( progress=self.progress_bar(self.removed, 0, self.to_remove), total=self.to_remove, + done=self.removed, ) embed.add_field(name=name, value=value, inline=False) else: @@ -221,10 +222,11 @@ class RankRefreshUI(MessageUI): )) value = t(_p( 'ui:refresh_ranks|embed|field:add|value', - "0 {progress} {total}" + "{progress} {done}/{total} given" )).format( progress=self.progress_bar(self.added, 0, self.to_add), total=self.to_add, + done=self.added, ) embed.add_field(name=name, value=value, inline=False) else: diff --git a/src/utils/ui/leo.py b/src/utils/ui/leo.py index e2eedd0d..9c3a4487 100644 --- a/src/utils/ui/leo.py +++ b/src/utils/ui/leo.py @@ -250,6 +250,8 @@ class MessageUI(LeoUI): """ Simple single-message LeoUI, intended as a framework for UIs attached to a single interaction response. + + UIs may also be sent as regular messages by using `send(channel)` instead of `run(interaction)`. """ def __init__(self, *args, callerid: Optional[int] = None, **kwargs):