From 9b2af56d64883152345a595714bc381ac5fa6b69 Mon Sep 17 00:00:00 2001 From: Conatum Date: Thu, 24 Aug 2023 14:30:03 +0300 Subject: [PATCH] feat: Implement rank refresh. --- src/modules/ranks/cog.py | 228 ++++++++++++++++++++++++++- src/modules/ranks/ui/__init__.py | 1 + src/modules/ranks/ui/overview.py | 3 +- src/modules/ranks/ui/refresh.py | 259 +++++++++++++++++++++++++++++++ src/tracking/text/data.py | 48 ++++++ src/utils/ratelimits.py | 35 +++++ 6 files changed, 571 insertions(+), 3 deletions(-) create mode 100644 src/modules/ranks/ui/refresh.py diff --git a/src/modules/ranks/cog.py b/src/modules/ranks/cog.py index f8b48f38..b0633988 100644 --- a/src/modules/ranks/cog.py +++ b/src/modules/ranks/cog.py @@ -13,12 +13,14 @@ from wards import high_management_ward, high_management_iward from core.data import RankType from utils.ui import ChoicedEnum, Transformed from utils.lib import utc_now, replace_multiple +from utils.ratelimits import Bucket, limit_concurrency +from utils.data import TemporaryTable from . import babel, logger from .data import RankData, AnyRankData from .settings import RankSettings -from .ui import RankOverviewUI, RankConfigUI +from .ui import RankOverviewUI, RankConfigUI, RankRefreshUI from .utils import rank_model_from_type, format_stat_range _p = babel._p @@ -158,6 +160,27 @@ class RankCog(LionCog): self._member_ranks[guildid] = cached return cached + def _get_stats_model(self, rank_type): + return { + RankType.MESSAGE: self.bot.get_cog('TextTrackerCog').data.TextSessions, + RankType.VOICE: self.bot.get_cog('StatsCog').data.VoiceSessionStats, + RankType.XP: self.bot.get_cog('StatsCog').data.MemberExp, + }[rank_type] + + def _get_rank_model(self, rank_type): + return { + RankType.MESSAGE: self.data.MsgRank, + RankType.VOICE: self.data.VoiceRank, + RankType.XP: self.data.XPRank, + }[rank_type] + + def _get_rankid_column(self, rank_type): + return { + RankType.MESSAGE: 'current_msg_rankid', + RankType.VOICE: 'current_voice_rankid', + RankType.XP: 'current_xp_rankid' + }[rank_type] + async def get_member_rank(self, guildid: int, userid: int) -> SeasonRank: """ Fetch the SeasonRank info for the given member. @@ -409,8 +432,14 @@ class RankCog(LionCog): # TODO: Locking between refresh and individual updates for guildid, userid, duration, guild_xp in session_data: lguild = await self.bot.core.lions.fetch_guild(guildid) + unranked_role_setting = await self.bot.get_cog('StatsCog').settings.UnrankedRoles.get(guildid) + unranked_roleids = set(unranked_role_setting.data) + guild = self.bot.get_guild(guildid) + member = guild.get_member(userid) if guild else None + if not member or member.bot or any (role.id in unranked_roleids for role in member.roles): + continue rank_type = lguild.config.get('rank_type').value - if rank_type in (RankType.VOICE, RankType.XP): + if rank_type in (RankType.VOICE,): if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members: session_rank = _members[userid] # TODO: Temporary measure @@ -433,6 +462,201 @@ class RankCog(LionCog): async def on_xp_update(self, *xp_data): ... + async def interactive_rank_refresh(self, interaction: discord.Interaction, guild: discord.Guild): + """ + Interactively update ranks for everyone in the given guild. + """ + t = self.bot.translator.t + if not interaction.response.is_done(): + await interaction.response.defer(thinking=True, ephemeral=False) + ui = RankRefreshUI(self.bot, guild, callerid=interaction.user.id) + await ui.run(interaction) + + # Retrieve fresh rank roles + ranks = await self.get_guild_ranks(guild.id, refresh=True) + ui.stage_ranks = True + ui.poke() + + # Ensure guild is chunked + if not guild.chunked: + members = await guild.chunk() + else: + members = guild.members + ui.stage_members = True + ui.poke() + + roles = {rank.roleid: guild.get_role(rank.roleid) for rank in ranks} + if not all(roles.values()): + error = t(_p( + 'rank_refresh|error:roles_dne|desc', + "Some ranks have invalid or deleted roles! Please remove them first." + )) + await ui.set_error(error) + return + + # Check that bot has permission to assign rank roles + failing = [role for role in roles.values() if not role.is_assignable()] + if failing: + error = t(_p( + 'rank_refresh|error:unassignable_roles|desc', + "I have insufficient permissions to assign the following role(s):\n{roles}" + )).format(roles='\n'.join(role.mention for role in failing)), + await ui.set_error(error) + return + + ui.stage_roles = True + ui.poke() + + # Now we are certain that all the rank roles exist and are assignable + # Compute season start and season leaderboard + lguild = await self.bot.core.lions.fetch_guild(guild.id) + season_start = lguild.config.get('season_start').value + rank_type = lguild.config.get('rank_type').value + stats_model = self._get_stats_model(rank_type) + if season_start: + leaderboard = await stats_model.leaderboard_since(guild.id, season_start) + else: + leaderboard = await stats_model.leaderboard_all(guild.id) + + # Compile map of correct ranks + # Filtering out members who are untracked or not in server + unranked_role_setting = await self.bot.get_cog('StatsCog').settings.UnrankedRoles.get(guild.id) + unranked_roleids = set(unranked_role_setting.data) + true_member_ranks: dict[int, RankData.VoiceRank | RankData.XPRank | RankData.MsgRank] = {} + for userid, stat_total in leaderboard: + # Check member exists + if member := guild.get_member(userid): + # Check member does not have unranked roles + if not (member.bot or any(role.id in unranked_roleids for role in member.roles)): + # Compute member rank + rank = next((rank for rank in reversed(ranks) if rank.required <= stat_total), None) + if rank is not None: + true_member_ranks[userid] = rank + + # Compile maps of member roles that need removal and member roles that need adding + to_remove: list[tuple[discord.Member, list[discord.Role]]] = [] + to_add: list[tuple[discord.Member, discord.Role]] = [] + for member in members: + if member.bot: + continue + true_rank = true_member_ranks.get(member.id, None) + true_roleid = true_rank.roleid if true_rank is not None else None + has_true = (true_roleid is None) + invalid = [] + for role in member.roles: + if role.id in roles: + if not has_true and role.id == true_roleid: + has_true = True + else: + invalid.append(role) + if invalid: + to_remove.append((member, invalid)) + if not has_true: + to_add.append((member, roles[true_roleid])) + + ui.stage_compute = True + ui.to_remove = len(to_remove) + ui.to_add = len(to_add) + ui.poke() + + # Perform operations + # Starting with removals + coros = [] + bucket = Bucket(4, 5) + + for member, roles in to_remove: + remove_coro = member.remove_roles( + *roles, + reason=t(_p( + 'rank_refresh|remove_roles|audit', + "Removing invalid rank role." + )) + ) + coros.append(bucket.wrapped(remove_coro)) + + index = 0 + async for task in limit_concurrency(coros, 5): + try: + await task + index += 1 + ui.poke() + except discord.HTTPException: + error = t(_p( + 'rank_refresh|remove_roles|small_error', + "*Could not remove ranks from {member}*" + )).format(member=to_remove[index][0].mention) + self.ui.errors.append(error) + if len(self.ui.errors) > 10: + await ui.set_error( + t(_p( + 'rank_refresh|remove_roles|error:too_many_issues', + "Too many issues occurred while removing ranks! " + "Please check my permissions and try again in a few minutes." + )) + ) + return + ui.removed += 1 + ui.poke() + + coros = [] + for member, role in to_add: + add_coro = member.add_roles( + role, + reason=t(_p( + 'rank_refresh|add_roles|audit', + "Adding rank role from refresh" + )) + ) + coros.append(bucket.wrapped(add_coro)) + + index = 0 + async for task in limit_concurrency(coros, 5): + try: + await task + index += 1 + ui.poke() + except discord.HTTPException: + error = t(_p( + 'rank_refresh|add_roles|small_error', + "*Could not add {role} to {member}*" + )).format(member=to_add[index][0].mention, role=to_add[index][1].mention) + self.ui.errors.append(error) + if len(self.ui.errors) > 10: + await ui.set_error( + t(_p( + 'rank_refresh|add_roles|error:too_many_issues', + "Too many issues occurred while adding ranks! " + "Please check my permissions and try again in a few minutes." + )) + ) + return + ui.added += 1 + ui.poke() + + # Save correct member ranks and given roles to data + # First clear the member rank data entirely + await self.data.MemberRank.table.delete_where(guildid=guild.id) + column = self._get_rankid_column(rank_type) + tmptable = TemporaryTable( + '_gid', '_uid', '_rankid', '_roleid', + types=('BIGINT', 'BIGINT', 'BIGINT', 'BIGINT') + ) + tmptable.values = [ + (guild.id, memberid, rank.rankid, rank.roleid) + for memberid, rank in true_member_ranks.items() + ] + if tmptable.values: + await self.data.MemberRank.table.update_where( + guildid=tmptable['_gid'], + userid=tmptable['_uid'] + ).set( + **{column: tmptable['_rankid'], 'last_roleid': tmptable['_roleid']} + ).from_expr(tmptable) + + self.flush_guild_ranks(guild.id) + await ui.set_done() + await ui.wait() + # ---------- Commands ---------- @cmds.hybrid_command(name=_p('cmd:ranks', "ranks")) async def ranks_cmd(self, ctx: LionContext): diff --git a/src/modules/ranks/ui/__init__.py b/src/modules/ranks/ui/__init__.py index a57569f7..b78874de 100644 --- a/src/modules/ranks/ui/__init__.py +++ b/src/modules/ranks/ui/__init__.py @@ -2,3 +2,4 @@ from .editor import RankEditor from .preview import RankPreviewUI from .overview import RankOverviewUI from .config import RankConfigUI +from .refresh import RankRefreshUI diff --git a/src/modules/ranks/ui/overview.py b/src/modules/ranks/ui/overview.py index 0ceb380a..a07a27d2 100644 --- a/src/modules/ranks/ui/overview.py +++ b/src/modules/ranks/ui/overview.py @@ -98,7 +98,8 @@ class RankOverviewUI(MessageUI): Refresh the current ranks, ensuring that all members have the correct rank. """ - await press.response.send_message("Not Implemented Yet") + cog = self.bot.get_cog('RankCog') + await cog.interactive_rank_refresh(press, self.guild) async def refresh_button_refresh(self): self.refresh_button.label = self.bot.translator.t(_p( diff --git a/src/modules/ranks/ui/refresh.py b/src/modules/ranks/ui/refresh.py new file mode 100644 index 00000000..9289c345 --- /dev/null +++ b/src/modules/ranks/ui/refresh.py @@ -0,0 +1,259 @@ +from typing import Optional +import asyncio + +import discord +from discord.ui.select import select, Select, SelectOption, RoleSelect +from discord.ui.button import button, Button, ButtonStyle + +from meta import conf, LionBot +from meta.logger import log_wrap +from core.data import RankType +from data import ORDER + +from utils.ui import MessageUI +from utils.lib import MessageArgs, utc_now +from babel.translator import ctx_translator + +from .. import babel, logger +from ..data import AnyRankData +from ..utils import rank_model_from_type, format_stat_range, stat_data_to_value +from .editor import RankEditor +from .preview import RankPreviewUI + +_p = babel._p + + +class RankRefreshUI(MessageUI): + def __init__(self, bot: LionBot, guild: discord.Guild, **kwargs): + super().__init__(**kwargs) + self.bot = bot + self.guild = guild + + self.stage_ranks = None + self.stage_members = None + self.stage_roles = None + self.stage_compute = None + + self.to_remove = 0 + self.to_add = 0 + self.removed = 0 + self.added = 0 + + self.error: Optional[str] = None + self.done = False + + self.errors: list[str] = [] + + self._loop_task: Optional[asyncio.Task] = None + self._wakeup = asyncio.Event() + + # ----- API ----- + async def set_error(self, error: str): + """ + Set the given error, refresh, and stop. + """ + self.error = error + await self.refresh() + await self.close() + + async def set_done(self): + self.done = True + await self.refresh() + await self.close() + + def poke(self): + self._wakeup.set() + + async def run(self, *args, **kwargs): + await super().run(*args, **kwargs) + self._loop_task = asyncio.create_task(self._refresh_loop(), name='refresh ui loop') + + async def cleanup(self): + if self._loop_task and not self._loop_task.done(): + self._loop_task.cancel() + await super().cleanup() + + def progress_bar(self, value, minimum, maximum, width=10) -> str: + """ + Build a text progress bar representing `value` between `minimum` and `maximum`. + """ + emojis = self.bot.config.emojis + + proportion = (value - minimum) / (maximum - minimum) + sections = min(max(int(proportion * width), 0), width) + + bar = [] + # Starting segment + bar.append(str(emojis.progress_left_empty) if sections == 0 else str(emojis.progress_left_full)) + + # Full segments up to transition or end + if sections >= 2: + bar.append(str(emojis.progress_middle_full) * (sections - 2)) + + # Transition, if required + if 1 < sections < width: + bar.append(str(emojis.progress_middle_transition)) + + # Empty sections up to end + if sections < width: + bar.append(str(emojis.progress_middle_empty) * (width - max(sections, 1) - 1)) + + # End section + bar.append(str(emojis.progress_right_empty) if sections < width else str(emojis.progress_right_full)) + + # Join all the sections together and return + return ''.join(bar) + + @log_wrap('refresh ui loop') + async def _refresh_loop(self): + while True: + try: + await asyncio.sleep(1) + await self._wakeup.wait() + await self.refresh() + except asyncio.CancelledError: + break + + # ----- UI Flow ----- + async def make_message(self) -> MessageArgs: + t = self.bot.translator.t + errored = bool(self.error) + if errored: + waiting_emoji = self.bot.config.emojis.cancel + title = t(_p( + 'ui:refresh_ranks|embed|title:errored', + "Could not refresh the server ranks!" + )) + colour = discord.Colour.brand_red() + else: + waiting_emoji = self.bot.config.emojis.loading + if self.done: + title = t(_p( + 'ui:refresh_ranks|embed|title:done', + "Rank refresh complete!" + )) + colour = discord.Colour.brand_green() + else: + title = t(_p( + 'ui:refresh_ranks|embed|title:working', + "Refreshing your server ranks, please wait." + )) + colour = discord.Colour.orange() + + embed = discord.Embed( + colour=colour, + title=title, + timestamp=utc_now() + ) + + lines = [] + stop_here = False + + if not stop_here: + stage = self.stage_ranks + emoji = self.bot.config.emojis.tick if stage else waiting_emoji + text = t(_p( + 'ui:refresh_ranks|embed|line:ranks', + "**Loading server ranks:** {emoji}" + )).format(emoji=emoji) + lines.append(text) + stop_here = not bool(stage) + + if not stop_here: + stage = self.stage_members + emoji = self.bot.config.emojis.tick if stage else waiting_emoji + text = t(_p( + 'ui:refresh_ranks|embed|line:members', + "**Loading server members:** {emoji}" + )).format(emoji=emoji) + lines.append(text) + stop_here = not bool(stage) + + if not stop_here: + stage = self.stage_roles + emoji = self.bot.config.emojis.tick if stage else waiting_emoji + text = t(_p( + 'ui:refresh_ranks|embed|line:roles', + "**Loading rank roles:** {emoji}" + )).format(emoji=emoji) + lines.append(text) + stop_here = not bool(stage) + + if not stop_here: + stage = self.stage_compute + emoji = self.bot.config.emojis.tick if stage else waiting_emoji + text = t(_p( + 'ui:refresh_ranks|embed|line:compute', + "**Computing correct ranks:** {emoji}" + )).format(emoji=emoji) + lines.append(text) + stop_here = not bool(stage) + + if not stop_here: + lines.append("") + if self.to_remove > self.removed and not errored: + # Still have members to remove, show loading bar + name = t(_p( + 'ui:refresh_ranks|embed|field:remove|name', + "Removing invalid rank roles from members" + )) + value = t(_p( + 'ui:refresh_ranks|embed|field:remove|value', + "0 {progress} {total}" + )).format( + progress=self.progress_bar(self.removed, 0, self.to_remove), + total=self.to_remove, + ) + embed.add_field(name=name, value=value, inline=False) + else: + emoji = self.bot.config.emojis.tick + text = t(_p( + 'ui:refresh_ranks|embed|line:remove', + "**Removed invalid ranks:** {done}/{target}" + )).format(done=self.removed, target=self.to_remove) + lines.append(text) + + if self.to_add > self.added and not errored: + # Still have members to add, show loading bar + name = t(_p( + 'ui:refresh_ranks|embed|field:add|name', + "Giving members their rank roles" + )) + value = t(_p( + 'ui:refresh_ranks|embed|field:add|value', + "0 {progress} {total}" + )).format( + progress=self.progress_bar(self.added, 0, self.to_add), + total=self.to_add, + ) + embed.add_field(name=name, value=value, inline=False) + else: + emoji = self.bot.config.emojis.tick + text = t(_p( + 'ui:refresh_ranks|embed|line:add', + "**Updated member ranks:** {done}/{target}" + )).format(done=self.added, target=self.to_add) + lines.append(text) + + embed.description = '\n'.join(lines) + if self.errors: + name = ( + 'ui:refresh_ranks|embed|field:errors|title', + "Issues" + ) + value = '\n'.join(self.errors) + embed.add_field(name=name, value=value, inline=False) + if self.error: + name = ( + 'ui:refresh_ranks|embed|field:critical|title', + "Critical Error! Cannot complete refresh" + ) + embed.add_field(name=name, value=self.error, inline=False) + + return MessageArgs(embed=embed) + + async def refresh_layout(self): + pass + + async def reload(self): + pass diff --git a/src/tracking/text/data.py b/src/tracking/text/data.py index d51ca8a3..dda4c65d 100644 --- a/src/tracking/text/data.py +++ b/src/tracking/text/data.py @@ -288,5 +288,53 @@ class TextTrackerData(Registry): tuple(chain((userid, guildid), points)) ) return [r['messages'] or 0 for r in await cursor.fetchall()] + + @classmethod + @log_wrap(action='msgs_leaderboard_all') + async def leaderboard_since(cls, guildid: int, since): + """ + Return the message count totals for the given guild since the given time. + """ + query = sql.SQL( + """ + SELECT userid, sum(messages) as user_total + FROM text_sessions + WHERE guildid = %s AND start_time >= %s + GROUP BY userid + ORDER BY + """ + ) + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(query, (guildid, since)) + leaderboard = [ + (row['userid'], int(row['user_total'])) + for row in await cursor.fetchall() + ] + return leaderboard + + @classmethod + @log_wrap(action='msgs_leaderboard_all') + async def leaderboard_all(cls, guildid: int): + """ + Return the all-time message count totals for the given guild. + """ + query = sql.SQL( + """ + SELECT userid, sum(messages) as user_total + FROM text_sessions + WHERE guildid = %s + GROUP BY userid + ORDER BY + """ + ) + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(query, (guildid,)) + leaderboard = [ + (row['userid'], int(row['user_total'])) + for row in await cursor.fetchall() + ] + return leaderboard untracked_channels = Table('untracked_text_channels') diff --git a/src/utils/ratelimits.py b/src/utils/ratelimits.py index 9dbfdc81..21c75160 100644 --- a/src/utils/ratelimits.py +++ b/src/utils/ratelimits.py @@ -1,10 +1,14 @@ import asyncio import time +import logging from meta.errors import SafeCancellation from cachetools import TTLCache +logger = logging.getLogger() + + class BucketFull(Exception): """ @@ -129,3 +133,34 @@ class RateLimit: return await func(ctx, *args, **kwargs) return wrapper return decorator + + +async def limit_concurrency(aws, limit): + """ + Run provided awaitables concurrently, + ensuring that no more than `limit` are running at once. + """ + aws = iter(aws) + aws_ended = False + pending = set() + count = 0 + logger.debug("Starting limited concurrency executor") + + while pending or not aws_ended: + while len(pending) < limit and not aws_ended: + aw = next(aws, None) + if aw is None: + aws_ended = True + else: + pending.add(asyncio.create_task(aw)) + count += 1 + + if not pending: + break + + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + while done: + yield done.pop() + logger.debug(f"Completed {count} tasks")