feat: Implement rank refresh.

This commit is contained in:
2023-08-24 14:30:03 +03:00
parent df9b835cd5
commit 9b2af56d64
6 changed files with 571 additions and 3 deletions

View File

@@ -13,12 +13,14 @@ from wards import high_management_ward, high_management_iward
from core.data import RankType from core.data import RankType
from utils.ui import ChoicedEnum, Transformed from utils.ui import ChoicedEnum, Transformed
from utils.lib import utc_now, replace_multiple from utils.lib import utc_now, replace_multiple
from utils.ratelimits import Bucket, limit_concurrency
from utils.data import TemporaryTable
from . import babel, logger from . import babel, logger
from .data import RankData, AnyRankData from .data import RankData, AnyRankData
from .settings import RankSettings from .settings import RankSettings
from .ui import RankOverviewUI, RankConfigUI from .ui import RankOverviewUI, RankConfigUI, RankRefreshUI
from .utils import rank_model_from_type, format_stat_range from .utils import rank_model_from_type, format_stat_range
_p = babel._p _p = babel._p
@@ -158,6 +160,27 @@ class RankCog(LionCog):
self._member_ranks[guildid] = cached self._member_ranks[guildid] = cached
return cached return cached
def _get_stats_model(self, rank_type):
return {
RankType.MESSAGE: self.bot.get_cog('TextTrackerCog').data.TextSessions,
RankType.VOICE: self.bot.get_cog('StatsCog').data.VoiceSessionStats,
RankType.XP: self.bot.get_cog('StatsCog').data.MemberExp,
}[rank_type]
def _get_rank_model(self, rank_type):
return {
RankType.MESSAGE: self.data.MsgRank,
RankType.VOICE: self.data.VoiceRank,
RankType.XP: self.data.XPRank,
}[rank_type]
def _get_rankid_column(self, rank_type):
return {
RankType.MESSAGE: 'current_msg_rankid',
RankType.VOICE: 'current_voice_rankid',
RankType.XP: 'current_xp_rankid'
}[rank_type]
async def get_member_rank(self, guildid: int, userid: int) -> SeasonRank: async def get_member_rank(self, guildid: int, userid: int) -> SeasonRank:
""" """
Fetch the SeasonRank info for the given member. Fetch the SeasonRank info for the given member.
@@ -409,8 +432,14 @@ class RankCog(LionCog):
# TODO: Locking between refresh and individual updates # TODO: Locking between refresh and individual updates
for guildid, userid, duration, guild_xp in session_data: for guildid, userid, duration, guild_xp in session_data:
lguild = await self.bot.core.lions.fetch_guild(guildid) lguild = await self.bot.core.lions.fetch_guild(guildid)
unranked_role_setting = await self.bot.get_cog('StatsCog').settings.UnrankedRoles.get(guildid)
unranked_roleids = set(unranked_role_setting.data)
guild = self.bot.get_guild(guildid)
member = guild.get_member(userid) if guild else None
if not member or member.bot or any (role.id in unranked_roleids for role in member.roles):
continue
rank_type = lguild.config.get('rank_type').value rank_type = lguild.config.get('rank_type').value
if rank_type in (RankType.VOICE, RankType.XP): if rank_type in (RankType.VOICE,):
if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members: if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members:
session_rank = _members[userid] session_rank = _members[userid]
# TODO: Temporary measure # TODO: Temporary measure
@@ -433,6 +462,201 @@ class RankCog(LionCog):
async def on_xp_update(self, *xp_data): async def on_xp_update(self, *xp_data):
... ...
async def interactive_rank_refresh(self, interaction: discord.Interaction, guild: discord.Guild):
"""
Interactively update ranks for everyone in the given guild.
"""
t = self.bot.translator.t
if not interaction.response.is_done():
await interaction.response.defer(thinking=True, ephemeral=False)
ui = RankRefreshUI(self.bot, guild, callerid=interaction.user.id)
await ui.run(interaction)
# Retrieve fresh rank roles
ranks = await self.get_guild_ranks(guild.id, refresh=True)
ui.stage_ranks = True
ui.poke()
# Ensure guild is chunked
if not guild.chunked:
members = await guild.chunk()
else:
members = guild.members
ui.stage_members = True
ui.poke()
roles = {rank.roleid: guild.get_role(rank.roleid) for rank in ranks}
if not all(roles.values()):
error = t(_p(
'rank_refresh|error:roles_dne|desc',
"Some ranks have invalid or deleted roles! Please remove them first."
))
await ui.set_error(error)
return
# Check that bot has permission to assign rank roles
failing = [role for role in roles.values() if not role.is_assignable()]
if failing:
error = t(_p(
'rank_refresh|error:unassignable_roles|desc',
"I have insufficient permissions to assign the following role(s):\n{roles}"
)).format(roles='\n'.join(role.mention for role in failing)),
await ui.set_error(error)
return
ui.stage_roles = True
ui.poke()
# Now we are certain that all the rank roles exist and are assignable
# Compute season start and season leaderboard
lguild = await self.bot.core.lions.fetch_guild(guild.id)
season_start = lguild.config.get('season_start').value
rank_type = lguild.config.get('rank_type').value
stats_model = self._get_stats_model(rank_type)
if season_start:
leaderboard = await stats_model.leaderboard_since(guild.id, season_start)
else:
leaderboard = await stats_model.leaderboard_all(guild.id)
# Compile map of correct ranks
# Filtering out members who are untracked or not in server
unranked_role_setting = await self.bot.get_cog('StatsCog').settings.UnrankedRoles.get(guild.id)
unranked_roleids = set(unranked_role_setting.data)
true_member_ranks: dict[int, RankData.VoiceRank | RankData.XPRank | RankData.MsgRank] = {}
for userid, stat_total in leaderboard:
# Check member exists
if member := guild.get_member(userid):
# Check member does not have unranked roles
if not (member.bot or any(role.id in unranked_roleids for role in member.roles)):
# Compute member rank
rank = next((rank for rank in reversed(ranks) if rank.required <= stat_total), None)
if rank is not None:
true_member_ranks[userid] = rank
# Compile maps of member roles that need removal and member roles that need adding
to_remove: list[tuple[discord.Member, list[discord.Role]]] = []
to_add: list[tuple[discord.Member, discord.Role]] = []
for member in members:
if member.bot:
continue
true_rank = true_member_ranks.get(member.id, None)
true_roleid = true_rank.roleid if true_rank is not None else None
has_true = (true_roleid is None)
invalid = []
for role in member.roles:
if role.id in roles:
if not has_true and role.id == true_roleid:
has_true = True
else:
invalid.append(role)
if invalid:
to_remove.append((member, invalid))
if not has_true:
to_add.append((member, roles[true_roleid]))
ui.stage_compute = True
ui.to_remove = len(to_remove)
ui.to_add = len(to_add)
ui.poke()
# Perform operations
# Starting with removals
coros = []
bucket = Bucket(4, 5)
for member, roles in to_remove:
remove_coro = member.remove_roles(
*roles,
reason=t(_p(
'rank_refresh|remove_roles|audit',
"Removing invalid rank role."
))
)
coros.append(bucket.wrapped(remove_coro))
index = 0
async for task in limit_concurrency(coros, 5):
try:
await task
index += 1
ui.poke()
except discord.HTTPException:
error = t(_p(
'rank_refresh|remove_roles|small_error',
"*Could not remove ranks from {member}*"
)).format(member=to_remove[index][0].mention)
self.ui.errors.append(error)
if len(self.ui.errors) > 10:
await ui.set_error(
t(_p(
'rank_refresh|remove_roles|error:too_many_issues',
"Too many issues occurred while removing ranks! "
"Please check my permissions and try again in a few minutes."
))
)
return
ui.removed += 1
ui.poke()
coros = []
for member, role in to_add:
add_coro = member.add_roles(
role,
reason=t(_p(
'rank_refresh|add_roles|audit',
"Adding rank role from refresh"
))
)
coros.append(bucket.wrapped(add_coro))
index = 0
async for task in limit_concurrency(coros, 5):
try:
await task
index += 1
ui.poke()
except discord.HTTPException:
error = t(_p(
'rank_refresh|add_roles|small_error',
"*Could not add {role} to {member}*"
)).format(member=to_add[index][0].mention, role=to_add[index][1].mention)
self.ui.errors.append(error)
if len(self.ui.errors) > 10:
await ui.set_error(
t(_p(
'rank_refresh|add_roles|error:too_many_issues',
"Too many issues occurred while adding ranks! "
"Please check my permissions and try again in a few minutes."
))
)
return
ui.added += 1
ui.poke()
# Save correct member ranks and given roles to data
# First clear the member rank data entirely
await self.data.MemberRank.table.delete_where(guildid=guild.id)
column = self._get_rankid_column(rank_type)
tmptable = TemporaryTable(
'_gid', '_uid', '_rankid', '_roleid',
types=('BIGINT', 'BIGINT', 'BIGINT', 'BIGINT')
)
tmptable.values = [
(guild.id, memberid, rank.rankid, rank.roleid)
for memberid, rank in true_member_ranks.items()
]
if tmptable.values:
await self.data.MemberRank.table.update_where(
guildid=tmptable['_gid'],
userid=tmptable['_uid']
).set(
**{column: tmptable['_rankid'], 'last_roleid': tmptable['_roleid']}
).from_expr(tmptable)
self.flush_guild_ranks(guild.id)
await ui.set_done()
await ui.wait()
# ---------- Commands ---------- # ---------- Commands ----------
@cmds.hybrid_command(name=_p('cmd:ranks', "ranks")) @cmds.hybrid_command(name=_p('cmd:ranks', "ranks"))
async def ranks_cmd(self, ctx: LionContext): async def ranks_cmd(self, ctx: LionContext):

View File

@@ -2,3 +2,4 @@ from .editor import RankEditor
from .preview import RankPreviewUI from .preview import RankPreviewUI
from .overview import RankOverviewUI from .overview import RankOverviewUI
from .config import RankConfigUI from .config import RankConfigUI
from .refresh import RankRefreshUI

View File

@@ -98,7 +98,8 @@ class RankOverviewUI(MessageUI):
Refresh the current ranks, Refresh the current ranks,
ensuring that all members have the correct rank. ensuring that all members have the correct rank.
""" """
await press.response.send_message("Not Implemented Yet") cog = self.bot.get_cog('RankCog')
await cog.interactive_rank_refresh(press, self.guild)
async def refresh_button_refresh(self): async def refresh_button_refresh(self):
self.refresh_button.label = self.bot.translator.t(_p( self.refresh_button.label = self.bot.translator.t(_p(

View File

@@ -0,0 +1,259 @@
from typing import Optional
import asyncio
import discord
from discord.ui.select import select, Select, SelectOption, RoleSelect
from discord.ui.button import button, Button, ButtonStyle
from meta import conf, LionBot
from meta.logger import log_wrap
from core.data import RankType
from data import ORDER
from utils.ui import MessageUI
from utils.lib import MessageArgs, utc_now
from babel.translator import ctx_translator
from .. import babel, logger
from ..data import AnyRankData
from ..utils import rank_model_from_type, format_stat_range, stat_data_to_value
from .editor import RankEditor
from .preview import RankPreviewUI
_p = babel._p
class RankRefreshUI(MessageUI):
def __init__(self, bot: LionBot, guild: discord.Guild, **kwargs):
super().__init__(**kwargs)
self.bot = bot
self.guild = guild
self.stage_ranks = None
self.stage_members = None
self.stage_roles = None
self.stage_compute = None
self.to_remove = 0
self.to_add = 0
self.removed = 0
self.added = 0
self.error: Optional[str] = None
self.done = False
self.errors: list[str] = []
self._loop_task: Optional[asyncio.Task] = None
self._wakeup = asyncio.Event()
# ----- API -----
async def set_error(self, error: str):
"""
Set the given error, refresh, and stop.
"""
self.error = error
await self.refresh()
await self.close()
async def set_done(self):
self.done = True
await self.refresh()
await self.close()
def poke(self):
self._wakeup.set()
async def run(self, *args, **kwargs):
await super().run(*args, **kwargs)
self._loop_task = asyncio.create_task(self._refresh_loop(), name='refresh ui loop')
async def cleanup(self):
if self._loop_task and not self._loop_task.done():
self._loop_task.cancel()
await super().cleanup()
def progress_bar(self, value, minimum, maximum, width=10) -> str:
"""
Build a text progress bar representing `value` between `minimum` and `maximum`.
"""
emojis = self.bot.config.emojis
proportion = (value - minimum) / (maximum - minimum)
sections = min(max(int(proportion * width), 0), width)
bar = []
# Starting segment
bar.append(str(emojis.progress_left_empty) if sections == 0 else str(emojis.progress_left_full))
# Full segments up to transition or end
if sections >= 2:
bar.append(str(emojis.progress_middle_full) * (sections - 2))
# Transition, if required
if 1 < sections < width:
bar.append(str(emojis.progress_middle_transition))
# Empty sections up to end
if sections < width:
bar.append(str(emojis.progress_middle_empty) * (width - max(sections, 1) - 1))
# End section
bar.append(str(emojis.progress_right_empty) if sections < width else str(emojis.progress_right_full))
# Join all the sections together and return
return ''.join(bar)
@log_wrap('refresh ui loop')
async def _refresh_loop(self):
while True:
try:
await asyncio.sleep(1)
await self._wakeup.wait()
await self.refresh()
except asyncio.CancelledError:
break
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
t = self.bot.translator.t
errored = bool(self.error)
if errored:
waiting_emoji = self.bot.config.emojis.cancel
title = t(_p(
'ui:refresh_ranks|embed|title:errored',
"Could not refresh the server ranks!"
))
colour = discord.Colour.brand_red()
else:
waiting_emoji = self.bot.config.emojis.loading
if self.done:
title = t(_p(
'ui:refresh_ranks|embed|title:done',
"Rank refresh complete!"
))
colour = discord.Colour.brand_green()
else:
title = t(_p(
'ui:refresh_ranks|embed|title:working',
"Refreshing your server ranks, please wait."
))
colour = discord.Colour.orange()
embed = discord.Embed(
colour=colour,
title=title,
timestamp=utc_now()
)
lines = []
stop_here = False
if not stop_here:
stage = self.stage_ranks
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
text = t(_p(
'ui:refresh_ranks|embed|line:ranks',
"**Loading server ranks:** {emoji}"
)).format(emoji=emoji)
lines.append(text)
stop_here = not bool(stage)
if not stop_here:
stage = self.stage_members
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
text = t(_p(
'ui:refresh_ranks|embed|line:members',
"**Loading server members:** {emoji}"
)).format(emoji=emoji)
lines.append(text)
stop_here = not bool(stage)
if not stop_here:
stage = self.stage_roles
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
text = t(_p(
'ui:refresh_ranks|embed|line:roles',
"**Loading rank roles:** {emoji}"
)).format(emoji=emoji)
lines.append(text)
stop_here = not bool(stage)
if not stop_here:
stage = self.stage_compute
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
text = t(_p(
'ui:refresh_ranks|embed|line:compute',
"**Computing correct ranks:** {emoji}"
)).format(emoji=emoji)
lines.append(text)
stop_here = not bool(stage)
if not stop_here:
lines.append("")
if self.to_remove > self.removed and not errored:
# Still have members to remove, show loading bar
name = t(_p(
'ui:refresh_ranks|embed|field:remove|name',
"Removing invalid rank roles from members"
))
value = t(_p(
'ui:refresh_ranks|embed|field:remove|value',
"0 {progress} {total}"
)).format(
progress=self.progress_bar(self.removed, 0, self.to_remove),
total=self.to_remove,
)
embed.add_field(name=name, value=value, inline=False)
else:
emoji = self.bot.config.emojis.tick
text = t(_p(
'ui:refresh_ranks|embed|line:remove',
"**Removed invalid ranks:** {done}/{target}"
)).format(done=self.removed, target=self.to_remove)
lines.append(text)
if self.to_add > self.added and not errored:
# Still have members to add, show loading bar
name = t(_p(
'ui:refresh_ranks|embed|field:add|name',
"Giving members their rank roles"
))
value = t(_p(
'ui:refresh_ranks|embed|field:add|value',
"0 {progress} {total}"
)).format(
progress=self.progress_bar(self.added, 0, self.to_add),
total=self.to_add,
)
embed.add_field(name=name, value=value, inline=False)
else:
emoji = self.bot.config.emojis.tick
text = t(_p(
'ui:refresh_ranks|embed|line:add',
"**Updated member ranks:** {done}/{target}"
)).format(done=self.added, target=self.to_add)
lines.append(text)
embed.description = '\n'.join(lines)
if self.errors:
name = (
'ui:refresh_ranks|embed|field:errors|title',
"Issues"
)
value = '\n'.join(self.errors)
embed.add_field(name=name, value=value, inline=False)
if self.error:
name = (
'ui:refresh_ranks|embed|field:critical|title',
"Critical Error! Cannot complete refresh"
)
embed.add_field(name=name, value=self.error, inline=False)
return MessageArgs(embed=embed)
async def refresh_layout(self):
pass
async def reload(self):
pass

View File

@@ -288,5 +288,53 @@ class TextTrackerData(Registry):
tuple(chain((userid, guildid), points)) tuple(chain((userid, guildid), points))
) )
return [r['messages'] or 0 for r in await cursor.fetchall()] return [r['messages'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='msgs_leaderboard_all')
async def leaderboard_since(cls, guildid: int, since):
"""
Return the message count totals for the given guild since the given time.
"""
query = sql.SQL(
"""
SELECT userid, sum(messages) as user_total
FROM text_sessions
WHERE guildid = %s AND start_time >= %s
GROUP BY userid
ORDER BY
"""
)
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, since))
leaderboard = [
(row['userid'], int(row['user_total']))
for row in await cursor.fetchall()
]
return leaderboard
@classmethod
@log_wrap(action='msgs_leaderboard_all')
async def leaderboard_all(cls, guildid: int):
"""
Return the all-time message count totals for the given guild.
"""
query = sql.SQL(
"""
SELECT userid, sum(messages) as user_total
FROM text_sessions
WHERE guildid = %s
GROUP BY userid
ORDER BY
"""
)
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid,))
leaderboard = [
(row['userid'], int(row['user_total']))
for row in await cursor.fetchall()
]
return leaderboard
untracked_channels = Table('untracked_text_channels') untracked_channels = Table('untracked_text_channels')

View File

@@ -1,10 +1,14 @@
import asyncio import asyncio
import time import time
import logging
from meta.errors import SafeCancellation from meta.errors import SafeCancellation
from cachetools import TTLCache from cachetools import TTLCache
logger = logging.getLogger()
class BucketFull(Exception): class BucketFull(Exception):
""" """
@@ -129,3 +133,34 @@ class RateLimit:
return await func(ctx, *args, **kwargs) return await func(ctx, *args, **kwargs)
return wrapper return wrapper
return decorator return decorator
async def limit_concurrency(aws, limit):
"""
Run provided awaitables concurrently,
ensuring that no more than `limit` are running at once.
"""
aws = iter(aws)
aws_ended = False
pending = set()
count = 0
logger.debug("Starting limited concurrency executor")
while pending or not aws_ended:
while len(pending) < limit and not aws_ended:
aw = next(aws, None)
if aw is None:
aws_ended = True
else:
pending.add(asyncio.create_task(aw))
count += 1
if not pending:
break
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
while done:
yield done.pop()
logger.debug(f"Completed {count} tasks")