fix(ranks): Tighten up rank refresh.

This commit is contained in:
2023-09-11 08:01:44 +03:00
parent f59f3093d8
commit 6cc253e428
4 changed files with 66 additions and 55 deletions

View File

@@ -1,6 +1,7 @@
from typing import Optional
import asyncio
import datetime
from weakref import WeakValueDictionary
import discord
from discord.ext import commands as cmds
@@ -128,6 +129,9 @@ class RankCog(LionCog):
# pop the guild whenever the season is updated or the rank type changes.
self._member_ranks = {}
# Weakly referenced Locks for each guild to serialise rank actions
self._rank_locks: dict[int, asyncio.Lock] = WeakValueDictionary()
async def cog_load(self):
await self.data.init()
@@ -138,6 +142,13 @@ class RankCog(LionCog):
configcog = self.bot.get_cog('ConfigCog')
self.crossload_group(self.configure_group, configcog.configure_group)
def ranklock(self, guildid):
lock = self._rank_locks.get(guildid, None)
if lock is None:
lock = self._rank_locks[guildid] = asyncio.Lock()
logger.debug(f"Getting rank lock for guild <guildid: {guildid}> (locked: {lock.locked()})")
return lock
# ---------- Event handlers ----------
# season_start setting event handler.. clears the guild season rank cache
@LionCog.listener('on_guildset_season_start')
@@ -257,26 +268,22 @@ class RankCog(LionCog):
"""
Handle batch of completed message sessions.
"""
tasks = []
# TODO: Thread safety
# TODO: Locking between refresh and individual updates
for guildid, userid, messages, guild_xp in session_data:
lguild = await self.bot.core.lions.fetch_guild(guildid)
rank_type = lguild.config.get('rank_type').value
if rank_type in (RankType.MESSAGE, RankType.XP):
if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members:
session_rank = _members[userid]
session_rank.stat += messages if (rank_type is RankType.MESSAGE) else guild_xp
else:
session_rank = await self.get_member_rank(guildid, userid)
async with self.ranklock(guildid):
if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members:
session_rank = _members[userid]
session_rank.stat += messages if (rank_type is RankType.MESSAGE) else guild_xp
else:
session_rank = await self.get_member_rank(guildid, userid)
if session_rank.next_rank is not None and session_rank.stat > session_rank.next_rank.required:
tasks.append(asyncio.create_task(self.update_rank(session_rank)))
else:
tasks.append(asyncio.create_task(self._role_check(session_rank)))
if tasks:
await asyncio.gather(*tasks)
if session_rank.next_rank is not None and session_rank.stat > session_rank.next_rank.required:
task = asyncio.create_task(self.update_rank(session_rank), name='update-message-rank')
else:
task = asyncio.create_task(self._role_check(session_rank), name='rank-role-check')
await task
async def _role_check(self, session_rank: SeasonRank):
guild = self.bot.get_guild(session_rank.guildid)
@@ -445,9 +452,6 @@ class RankCog(LionCog):
@log_wrap(action="Voice Rank Hook")
async def on_voice_session_complete(self, *session_data):
tasks = []
# TODO: Thread safety
# TODO: Locking between refresh and individual updates
for guildid, userid, duration, guild_xp in session_data:
lguild = await self.bot.core.lions.fetch_guild(guildid)
unranked_role_setting = await self.bot.get_cog('StatsCog').settings.UnrankedRoles.get(guildid)
@@ -458,27 +462,28 @@ class RankCog(LionCog):
continue
rank_type = lguild.config.get('rank_type').value
if rank_type in (RankType.VOICE,):
if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members:
session_rank = _members[userid]
# TODO: Temporary measure
season_start = lguild.config.get('season_start').value or datetime.datetime(1970, 1, 1)
stat_data = self.bot.get_cog('StatsCog').data
session_rank.stat = (await stat_data.VoiceSessionStats.study_times_since(
guildid, userid, season_start)
)[0]
# session_rank.stat += duration if (rank_type is RankType.VOICE) else guild_xp
else:
session_rank = await self.get_member_rank(guildid, userid)
async with self.ranklock(guildid):
if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members:
session_rank = _members[userid]
# TODO: Temporary measure
season_start = lguild.config.get('season_start').value or datetime.datetime(1970, 1, 1)
stat_data = self.bot.get_cog('StatsCog').data
session_rank.stat = (await stat_data.VoiceSessionStats.study_times_since(
guildid, userid, season_start)
)[0]
# session_rank.stat += duration if (rank_type is RankType.VOICE) else guild_xp
else:
session_rank = await self.get_member_rank(guildid, userid)
if session_rank.next_rank is not None and session_rank.stat > session_rank.next_rank.required:
tasks.append(asyncio.create_task(self.update_rank(session_rank)))
else:
tasks.append(asyncio.create_task(self._role_check(session_rank)))
if tasks:
await asyncio.gather(*tasks)
if session_rank.next_rank is not None and session_rank.stat > session_rank.next_rank.required:
task = asyncio.create_task(self.update_rank(session_rank), name='voice-rank-update')
else:
task = asyncio.create_task(self._role_check(session_rank), name='voice-role-check')
async def on_xp_update(self, *xp_data):
...
# Currently no-op since xp is given purely by message stats
# Implement if xp ever becomes a combination of message and voice stats
pass
@log_wrap(action='interactive rank refresh')
async def interactive_rank_refresh(self, interaction: discord.Interaction, guild: discord.Guild):
@@ -487,9 +492,9 @@ class RankCog(LionCog):
"""
t = self.bot.translator.t
if not interaction.response.is_done():
await interaction.response.defer(thinking=True, ephemeral=False)
await interaction.response.defer(thinking=False)
ui = RankRefreshUI(self.bot, guild, callerid=interaction.user.id, timeout=None)
await ui.run(interaction)
await ui.send(interaction.channel)
# Retrieve fresh rank roles
ranks = await self.get_guild_ranks(guild.id, refresh=True)
@@ -655,18 +660,18 @@ class RankCog(LionCog):
# Save correct member ranks and given roles to data
# First clear the member rank data entirely
await self.data.MemberRank.table.delete_where(guildid=guild.id)
column = self._get_rankid_column(rank_type)
values = [
(guild.id, memberid, rank.rankid, rank.roleid)
for memberid, rank in true_member_ranks.items()
]
await self.data.MemberRank.table.insert_many(
('guildid', 'userid', column, 'last_roleid'),
*values
)
if true_member_ranks:
column = self._get_rankid_column(rank_type)
values = [
(guild.id, memberid, rank.rankid, rank.roleid)
for memberid, rank in true_member_ranks.items()
]
await self.data.MemberRank.table.insert_many(
('guildid', 'userid', column, 'last_roleid'),
*values
)
self.flush_guild_ranks(guild.id)
await ui.set_done()
await ui.wait()
# ---------- Commands ----------
@cmds.hybrid_command(name=_p('cmd:ranks', "ranks"))

View File

@@ -31,6 +31,7 @@ class RankOverviewUI(MessageUI):
self.bot = bot
self.guild = guild
self.guildid = guild.id
self.cog = bot.get_cog('RankCog')
self.lguild = None
@@ -99,8 +100,8 @@ class RankOverviewUI(MessageUI):
Refresh the current ranks,
ensuring that all members have the correct rank.
"""
cog = self.bot.get_cog('RankCog')
await cog.interactive_rank_refresh(press, self.guild)
async with self.cog.ranklock(self.guild.id):
await self.cog.interactive_rank_refresh(press, self.guild)
async def refresh_button_refresh(self):
self.refresh_button.label = self.bot.translator.t(_p(
@@ -135,9 +136,10 @@ class RankOverviewUI(MessageUI):
except ResponseTimedOut:
result = False
if result:
await self.rank_model.table.delete_where(guildid=self.guildid)
self.bot.get_cog('RankCog').flush_guild_ranks(self.guild.id)
self.ranks = []
async with self.cog.ranklock(self.guild.id):
await self.rank_model.table.delete_where(guildid=self.guildid)
self.cog.flush_guild_ranks(self.guild.id)
self.ranks = []
await self.redraw()
async def clear_button_refresh(self):

View File

@@ -199,10 +199,11 @@ class RankRefreshUI(MessageUI):
))
value = t(_p(
'ui:refresh_ranks|embed|field:remove|value',
"0 {progress} {total}"
"{progress} {done}/{total} removed"
)).format(
progress=self.progress_bar(self.removed, 0, self.to_remove),
total=self.to_remove,
done=self.removed,
)
embed.add_field(name=name, value=value, inline=False)
else:
@@ -221,10 +222,11 @@ class RankRefreshUI(MessageUI):
))
value = t(_p(
'ui:refresh_ranks|embed|field:add|value',
"0 {progress} {total}"
"{progress} {done}/{total} given"
)).format(
progress=self.progress_bar(self.added, 0, self.to_add),
total=self.to_add,
done=self.added,
)
embed.add_field(name=name, value=value, inline=False)
else: