feat: Implement rank refresh.

This commit is contained in:
2023-08-24 14:30:03 +03:00
parent df9b835cd5
commit 9b2af56d64
6 changed files with 571 additions and 3 deletions

View File

@@ -13,12 +13,14 @@ from wards import high_management_ward, high_management_iward
from core.data import RankType
from utils.ui import ChoicedEnum, Transformed
from utils.lib import utc_now, replace_multiple
from utils.ratelimits import Bucket, limit_concurrency
from utils.data import TemporaryTable
from . import babel, logger
from .data import RankData, AnyRankData
from .settings import RankSettings
from .ui import RankOverviewUI, RankConfigUI
from .ui import RankOverviewUI, RankConfigUI, RankRefreshUI
from .utils import rank_model_from_type, format_stat_range
_p = babel._p
@@ -158,6 +160,27 @@ class RankCog(LionCog):
self._member_ranks[guildid] = cached
return cached
def _get_stats_model(self, rank_type):
return {
RankType.MESSAGE: self.bot.get_cog('TextTrackerCog').data.TextSessions,
RankType.VOICE: self.bot.get_cog('StatsCog').data.VoiceSessionStats,
RankType.XP: self.bot.get_cog('StatsCog').data.MemberExp,
}[rank_type]
def _get_rank_model(self, rank_type):
return {
RankType.MESSAGE: self.data.MsgRank,
RankType.VOICE: self.data.VoiceRank,
RankType.XP: self.data.XPRank,
}[rank_type]
def _get_rankid_column(self, rank_type):
return {
RankType.MESSAGE: 'current_msg_rankid',
RankType.VOICE: 'current_voice_rankid',
RankType.XP: 'current_xp_rankid'
}[rank_type]
async def get_member_rank(self, guildid: int, userid: int) -> SeasonRank:
"""
Fetch the SeasonRank info for the given member.
@@ -409,8 +432,14 @@ class RankCog(LionCog):
# TODO: Locking between refresh and individual updates
for guildid, userid, duration, guild_xp in session_data:
lguild = await self.bot.core.lions.fetch_guild(guildid)
unranked_role_setting = await self.bot.get_cog('StatsCog').settings.UnrankedRoles.get(guildid)
unranked_roleids = set(unranked_role_setting.data)
guild = self.bot.get_guild(guildid)
member = guild.get_member(userid) if guild else None
if not member or member.bot or any (role.id in unranked_roleids for role in member.roles):
continue
rank_type = lguild.config.get('rank_type').value
if rank_type in (RankType.VOICE, RankType.XP):
if rank_type in (RankType.VOICE,):
if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members:
session_rank = _members[userid]
# TODO: Temporary measure
@@ -433,6 +462,201 @@ class RankCog(LionCog):
async def on_xp_update(self, *xp_data):
...
async def interactive_rank_refresh(self, interaction: discord.Interaction, guild: discord.Guild):
"""
Interactively update ranks for everyone in the given guild.
"""
t = self.bot.translator.t
if not interaction.response.is_done():
await interaction.response.defer(thinking=True, ephemeral=False)
ui = RankRefreshUI(self.bot, guild, callerid=interaction.user.id)
await ui.run(interaction)
# Retrieve fresh rank roles
ranks = await self.get_guild_ranks(guild.id, refresh=True)
ui.stage_ranks = True
ui.poke()
# Ensure guild is chunked
if not guild.chunked:
members = await guild.chunk()
else:
members = guild.members
ui.stage_members = True
ui.poke()
roles = {rank.roleid: guild.get_role(rank.roleid) for rank in ranks}
if not all(roles.values()):
error = t(_p(
'rank_refresh|error:roles_dne|desc',
"Some ranks have invalid or deleted roles! Please remove them first."
))
await ui.set_error(error)
return
# Check that bot has permission to assign rank roles
failing = [role for role in roles.values() if not role.is_assignable()]
if failing:
error = t(_p(
'rank_refresh|error:unassignable_roles|desc',
"I have insufficient permissions to assign the following role(s):\n{roles}"
)).format(roles='\n'.join(role.mention for role in failing)),
await ui.set_error(error)
return
ui.stage_roles = True
ui.poke()
# Now we are certain that all the rank roles exist and are assignable
# Compute season start and season leaderboard
lguild = await self.bot.core.lions.fetch_guild(guild.id)
season_start = lguild.config.get('season_start').value
rank_type = lguild.config.get('rank_type').value
stats_model = self._get_stats_model(rank_type)
if season_start:
leaderboard = await stats_model.leaderboard_since(guild.id, season_start)
else:
leaderboard = await stats_model.leaderboard_all(guild.id)
# Compile map of correct ranks
# Filtering out members who are untracked or not in server
unranked_role_setting = await self.bot.get_cog('StatsCog').settings.UnrankedRoles.get(guild.id)
unranked_roleids = set(unranked_role_setting.data)
true_member_ranks: dict[int, RankData.VoiceRank | RankData.XPRank | RankData.MsgRank] = {}
for userid, stat_total in leaderboard:
# Check member exists
if member := guild.get_member(userid):
# Check member does not have unranked roles
if not (member.bot or any(role.id in unranked_roleids for role in member.roles)):
# Compute member rank
rank = next((rank for rank in reversed(ranks) if rank.required <= stat_total), None)
if rank is not None:
true_member_ranks[userid] = rank
# Compile maps of member roles that need removal and member roles that need adding
to_remove: list[tuple[discord.Member, list[discord.Role]]] = []
to_add: list[tuple[discord.Member, discord.Role]] = []
for member in members:
if member.bot:
continue
true_rank = true_member_ranks.get(member.id, None)
true_roleid = true_rank.roleid if true_rank is not None else None
has_true = (true_roleid is None)
invalid = []
for role in member.roles:
if role.id in roles:
if not has_true and role.id == true_roleid:
has_true = True
else:
invalid.append(role)
if invalid:
to_remove.append((member, invalid))
if not has_true:
to_add.append((member, roles[true_roleid]))
ui.stage_compute = True
ui.to_remove = len(to_remove)
ui.to_add = len(to_add)
ui.poke()
# Perform operations
# Starting with removals
coros = []
bucket = Bucket(4, 5)
for member, roles in to_remove:
remove_coro = member.remove_roles(
*roles,
reason=t(_p(
'rank_refresh|remove_roles|audit',
"Removing invalid rank role."
))
)
coros.append(bucket.wrapped(remove_coro))
index = 0
async for task in limit_concurrency(coros, 5):
try:
await task
index += 1
ui.poke()
except discord.HTTPException:
error = t(_p(
'rank_refresh|remove_roles|small_error',
"*Could not remove ranks from {member}*"
)).format(member=to_remove[index][0].mention)
self.ui.errors.append(error)
if len(self.ui.errors) > 10:
await ui.set_error(
t(_p(
'rank_refresh|remove_roles|error:too_many_issues',
"Too many issues occurred while removing ranks! "
"Please check my permissions and try again in a few minutes."
))
)
return
ui.removed += 1
ui.poke()
coros = []
for member, role in to_add:
add_coro = member.add_roles(
role,
reason=t(_p(
'rank_refresh|add_roles|audit',
"Adding rank role from refresh"
))
)
coros.append(bucket.wrapped(add_coro))
index = 0
async for task in limit_concurrency(coros, 5):
try:
await task
index += 1
ui.poke()
except discord.HTTPException:
error = t(_p(
'rank_refresh|add_roles|small_error',
"*Could not add {role} to {member}*"
)).format(member=to_add[index][0].mention, role=to_add[index][1].mention)
self.ui.errors.append(error)
if len(self.ui.errors) > 10:
await ui.set_error(
t(_p(
'rank_refresh|add_roles|error:too_many_issues',
"Too many issues occurred while adding ranks! "
"Please check my permissions and try again in a few minutes."
))
)
return
ui.added += 1
ui.poke()
# Save correct member ranks and given roles to data
# First clear the member rank data entirely
await self.data.MemberRank.table.delete_where(guildid=guild.id)
column = self._get_rankid_column(rank_type)
tmptable = TemporaryTable(
'_gid', '_uid', '_rankid', '_roleid',
types=('BIGINT', 'BIGINT', 'BIGINT', 'BIGINT')
)
tmptable.values = [
(guild.id, memberid, rank.rankid, rank.roleid)
for memberid, rank in true_member_ranks.items()
]
if tmptable.values:
await self.data.MemberRank.table.update_where(
guildid=tmptable['_gid'],
userid=tmptable['_uid']
).set(
**{column: tmptable['_rankid'], 'last_roleid': tmptable['_roleid']}
).from_expr(tmptable)
self.flush_guild_ranks(guild.id)
await ui.set_done()
await ui.wait()
# ---------- Commands ----------
@cmds.hybrid_command(name=_p('cmd:ranks', "ranks"))
async def ranks_cmd(self, ctx: LionContext):

View File

@@ -2,3 +2,4 @@ from .editor import RankEditor
from .preview import RankPreviewUI
from .overview import RankOverviewUI
from .config import RankConfigUI
from .refresh import RankRefreshUI

View File

@@ -98,7 +98,8 @@ class RankOverviewUI(MessageUI):
Refresh the current ranks,
ensuring that all members have the correct rank.
"""
await press.response.send_message("Not Implemented Yet")
cog = self.bot.get_cog('RankCog')
await cog.interactive_rank_refresh(press, self.guild)
async def refresh_button_refresh(self):
self.refresh_button.label = self.bot.translator.t(_p(

View File

@@ -0,0 +1,259 @@
from typing import Optional
import asyncio
import discord
from discord.ui.select import select, Select, SelectOption, RoleSelect
from discord.ui.button import button, Button, ButtonStyle
from meta import conf, LionBot
from meta.logger import log_wrap
from core.data import RankType
from data import ORDER
from utils.ui import MessageUI
from utils.lib import MessageArgs, utc_now
from babel.translator import ctx_translator
from .. import babel, logger
from ..data import AnyRankData
from ..utils import rank_model_from_type, format_stat_range, stat_data_to_value
from .editor import RankEditor
from .preview import RankPreviewUI
_p = babel._p
class RankRefreshUI(MessageUI):
def __init__(self, bot: LionBot, guild: discord.Guild, **kwargs):
super().__init__(**kwargs)
self.bot = bot
self.guild = guild
self.stage_ranks = None
self.stage_members = None
self.stage_roles = None
self.stage_compute = None
self.to_remove = 0
self.to_add = 0
self.removed = 0
self.added = 0
self.error: Optional[str] = None
self.done = False
self.errors: list[str] = []
self._loop_task: Optional[asyncio.Task] = None
self._wakeup = asyncio.Event()
# ----- API -----
async def set_error(self, error: str):
"""
Set the given error, refresh, and stop.
"""
self.error = error
await self.refresh()
await self.close()
async def set_done(self):
self.done = True
await self.refresh()
await self.close()
def poke(self):
self._wakeup.set()
async def run(self, *args, **kwargs):
await super().run(*args, **kwargs)
self._loop_task = asyncio.create_task(self._refresh_loop(), name='refresh ui loop')
async def cleanup(self):
if self._loop_task and not self._loop_task.done():
self._loop_task.cancel()
await super().cleanup()
def progress_bar(self, value, minimum, maximum, width=10) -> str:
"""
Build a text progress bar representing `value` between `minimum` and `maximum`.
"""
emojis = self.bot.config.emojis
proportion = (value - minimum) / (maximum - minimum)
sections = min(max(int(proportion * width), 0), width)
bar = []
# Starting segment
bar.append(str(emojis.progress_left_empty) if sections == 0 else str(emojis.progress_left_full))
# Full segments up to transition or end
if sections >= 2:
bar.append(str(emojis.progress_middle_full) * (sections - 2))
# Transition, if required
if 1 < sections < width:
bar.append(str(emojis.progress_middle_transition))
# Empty sections up to end
if sections < width:
bar.append(str(emojis.progress_middle_empty) * (width - max(sections, 1) - 1))
# End section
bar.append(str(emojis.progress_right_empty) if sections < width else str(emojis.progress_right_full))
# Join all the sections together and return
return ''.join(bar)
@log_wrap('refresh ui loop')
async def _refresh_loop(self):
while True:
try:
await asyncio.sleep(1)
await self._wakeup.wait()
await self.refresh()
except asyncio.CancelledError:
break
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
t = self.bot.translator.t
errored = bool(self.error)
if errored:
waiting_emoji = self.bot.config.emojis.cancel
title = t(_p(
'ui:refresh_ranks|embed|title:errored',
"Could not refresh the server ranks!"
))
colour = discord.Colour.brand_red()
else:
waiting_emoji = self.bot.config.emojis.loading
if self.done:
title = t(_p(
'ui:refresh_ranks|embed|title:done',
"Rank refresh complete!"
))
colour = discord.Colour.brand_green()
else:
title = t(_p(
'ui:refresh_ranks|embed|title:working',
"Refreshing your server ranks, please wait."
))
colour = discord.Colour.orange()
embed = discord.Embed(
colour=colour,
title=title,
timestamp=utc_now()
)
lines = []
stop_here = False
if not stop_here:
stage = self.stage_ranks
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
text = t(_p(
'ui:refresh_ranks|embed|line:ranks',
"**Loading server ranks:** {emoji}"
)).format(emoji=emoji)
lines.append(text)
stop_here = not bool(stage)
if not stop_here:
stage = self.stage_members
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
text = t(_p(
'ui:refresh_ranks|embed|line:members',
"**Loading server members:** {emoji}"
)).format(emoji=emoji)
lines.append(text)
stop_here = not bool(stage)
if not stop_here:
stage = self.stage_roles
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
text = t(_p(
'ui:refresh_ranks|embed|line:roles',
"**Loading rank roles:** {emoji}"
)).format(emoji=emoji)
lines.append(text)
stop_here = not bool(stage)
if not stop_here:
stage = self.stage_compute
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
text = t(_p(
'ui:refresh_ranks|embed|line:compute',
"**Computing correct ranks:** {emoji}"
)).format(emoji=emoji)
lines.append(text)
stop_here = not bool(stage)
if not stop_here:
lines.append("")
if self.to_remove > self.removed and not errored:
# Still have members to remove, show loading bar
name = t(_p(
'ui:refresh_ranks|embed|field:remove|name',
"Removing invalid rank roles from members"
))
value = t(_p(
'ui:refresh_ranks|embed|field:remove|value',
"0 {progress} {total}"
)).format(
progress=self.progress_bar(self.removed, 0, self.to_remove),
total=self.to_remove,
)
embed.add_field(name=name, value=value, inline=False)
else:
emoji = self.bot.config.emojis.tick
text = t(_p(
'ui:refresh_ranks|embed|line:remove',
"**Removed invalid ranks:** {done}/{target}"
)).format(done=self.removed, target=self.to_remove)
lines.append(text)
if self.to_add > self.added and not errored:
# Still have members to add, show loading bar
name = t(_p(
'ui:refresh_ranks|embed|field:add|name',
"Giving members their rank roles"
))
value = t(_p(
'ui:refresh_ranks|embed|field:add|value',
"0 {progress} {total}"
)).format(
progress=self.progress_bar(self.added, 0, self.to_add),
total=self.to_add,
)
embed.add_field(name=name, value=value, inline=False)
else:
emoji = self.bot.config.emojis.tick
text = t(_p(
'ui:refresh_ranks|embed|line:add',
"**Updated member ranks:** {done}/{target}"
)).format(done=self.added, target=self.to_add)
lines.append(text)
embed.description = '\n'.join(lines)
if self.errors:
name = (
'ui:refresh_ranks|embed|field:errors|title',
"Issues"
)
value = '\n'.join(self.errors)
embed.add_field(name=name, value=value, inline=False)
if self.error:
name = (
'ui:refresh_ranks|embed|field:critical|title',
"Critical Error! Cannot complete refresh"
)
embed.add_field(name=name, value=self.error, inline=False)
return MessageArgs(embed=embed)
async def refresh_layout(self):
pass
async def reload(self):
pass

View File

@@ -289,4 +289,52 @@ class TextTrackerData(Registry):
)
return [r['messages'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='msgs_leaderboard_all')
async def leaderboard_since(cls, guildid: int, since):
"""
Return the message count totals for the given guild since the given time.
"""
query = sql.SQL(
"""
SELECT userid, sum(messages) as user_total
FROM text_sessions
WHERE guildid = %s AND start_time >= %s
GROUP BY userid
ORDER BY
"""
)
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, since))
leaderboard = [
(row['userid'], int(row['user_total']))
for row in await cursor.fetchall()
]
return leaderboard
@classmethod
@log_wrap(action='msgs_leaderboard_all')
async def leaderboard_all(cls, guildid: int):
"""
Return the all-time message count totals for the given guild.
"""
query = sql.SQL(
"""
SELECT userid, sum(messages) as user_total
FROM text_sessions
WHERE guildid = %s
GROUP BY userid
ORDER BY
"""
)
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid,))
leaderboard = [
(row['userid'], int(row['user_total']))
for row in await cursor.fetchall()
]
return leaderboard
untracked_channels = Table('untracked_text_channels')

View File

@@ -1,10 +1,14 @@
import asyncio
import time
import logging
from meta.errors import SafeCancellation
from cachetools import TTLCache
logger = logging.getLogger()
class BucketFull(Exception):
"""
@@ -129,3 +133,34 @@ class RateLimit:
return await func(ctx, *args, **kwargs)
return wrapper
return decorator
async def limit_concurrency(aws, limit):
"""
Run provided awaitables concurrently,
ensuring that no more than `limit` are running at once.
"""
aws = iter(aws)
aws_ended = False
pending = set()
count = 0
logger.debug("Starting limited concurrency executor")
while pending or not aws_ended:
while len(pending) < limit and not aws_ended:
aw = next(aws, None)
if aw is None:
aws_ended = True
else:
pending.add(asyncio.create_task(aw))
count += 1
if not pending:
break
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
while done:
yield done.pop()
logger.debug(f"Completed {count} tasks")