feat: Implement rank refresh.
This commit is contained in:
@@ -13,12 +13,14 @@ from wards import high_management_ward, high_management_iward
|
||||
from core.data import RankType
|
||||
from utils.ui import ChoicedEnum, Transformed
|
||||
from utils.lib import utc_now, replace_multiple
|
||||
from utils.ratelimits import Bucket, limit_concurrency
|
||||
from utils.data import TemporaryTable
|
||||
|
||||
|
||||
from . import babel, logger
|
||||
from .data import RankData, AnyRankData
|
||||
from .settings import RankSettings
|
||||
from .ui import RankOverviewUI, RankConfigUI
|
||||
from .ui import RankOverviewUI, RankConfigUI, RankRefreshUI
|
||||
from .utils import rank_model_from_type, format_stat_range
|
||||
|
||||
_p = babel._p
|
||||
@@ -158,6 +160,27 @@ class RankCog(LionCog):
|
||||
self._member_ranks[guildid] = cached
|
||||
return cached
|
||||
|
||||
def _get_stats_model(self, rank_type):
|
||||
return {
|
||||
RankType.MESSAGE: self.bot.get_cog('TextTrackerCog').data.TextSessions,
|
||||
RankType.VOICE: self.bot.get_cog('StatsCog').data.VoiceSessionStats,
|
||||
RankType.XP: self.bot.get_cog('StatsCog').data.MemberExp,
|
||||
}[rank_type]
|
||||
|
||||
def _get_rank_model(self, rank_type):
|
||||
return {
|
||||
RankType.MESSAGE: self.data.MsgRank,
|
||||
RankType.VOICE: self.data.VoiceRank,
|
||||
RankType.XP: self.data.XPRank,
|
||||
}[rank_type]
|
||||
|
||||
def _get_rankid_column(self, rank_type):
|
||||
return {
|
||||
RankType.MESSAGE: 'current_msg_rankid',
|
||||
RankType.VOICE: 'current_voice_rankid',
|
||||
RankType.XP: 'current_xp_rankid'
|
||||
}[rank_type]
|
||||
|
||||
async def get_member_rank(self, guildid: int, userid: int) -> SeasonRank:
|
||||
"""
|
||||
Fetch the SeasonRank info for the given member.
|
||||
@@ -409,8 +432,14 @@ class RankCog(LionCog):
|
||||
# TODO: Locking between refresh and individual updates
|
||||
for guildid, userid, duration, guild_xp in session_data:
|
||||
lguild = await self.bot.core.lions.fetch_guild(guildid)
|
||||
unranked_role_setting = await self.bot.get_cog('StatsCog').settings.UnrankedRoles.get(guildid)
|
||||
unranked_roleids = set(unranked_role_setting.data)
|
||||
guild = self.bot.get_guild(guildid)
|
||||
member = guild.get_member(userid) if guild else None
|
||||
if not member or member.bot or any (role.id in unranked_roleids for role in member.roles):
|
||||
continue
|
||||
rank_type = lguild.config.get('rank_type').value
|
||||
if rank_type in (RankType.VOICE, RankType.XP):
|
||||
if rank_type in (RankType.VOICE,):
|
||||
if (_members := self._member_ranks.get(guildid, None)) is not None and userid in _members:
|
||||
session_rank = _members[userid]
|
||||
# TODO: Temporary measure
|
||||
@@ -433,6 +462,201 @@ class RankCog(LionCog):
|
||||
async def on_xp_update(self, *xp_data):
|
||||
...
|
||||
|
||||
async def interactive_rank_refresh(self, interaction: discord.Interaction, guild: discord.Guild):
|
||||
"""
|
||||
Interactively update ranks for everyone in the given guild.
|
||||
"""
|
||||
t = self.bot.translator.t
|
||||
if not interaction.response.is_done():
|
||||
await interaction.response.defer(thinking=True, ephemeral=False)
|
||||
ui = RankRefreshUI(self.bot, guild, callerid=interaction.user.id)
|
||||
await ui.run(interaction)
|
||||
|
||||
# Retrieve fresh rank roles
|
||||
ranks = await self.get_guild_ranks(guild.id, refresh=True)
|
||||
ui.stage_ranks = True
|
||||
ui.poke()
|
||||
|
||||
# Ensure guild is chunked
|
||||
if not guild.chunked:
|
||||
members = await guild.chunk()
|
||||
else:
|
||||
members = guild.members
|
||||
ui.stage_members = True
|
||||
ui.poke()
|
||||
|
||||
roles = {rank.roleid: guild.get_role(rank.roleid) for rank in ranks}
|
||||
if not all(roles.values()):
|
||||
error = t(_p(
|
||||
'rank_refresh|error:roles_dne|desc',
|
||||
"Some ranks have invalid or deleted roles! Please remove them first."
|
||||
))
|
||||
await ui.set_error(error)
|
||||
return
|
||||
|
||||
# Check that bot has permission to assign rank roles
|
||||
failing = [role for role in roles.values() if not role.is_assignable()]
|
||||
if failing:
|
||||
error = t(_p(
|
||||
'rank_refresh|error:unassignable_roles|desc',
|
||||
"I have insufficient permissions to assign the following role(s):\n{roles}"
|
||||
)).format(roles='\n'.join(role.mention for role in failing)),
|
||||
await ui.set_error(error)
|
||||
return
|
||||
|
||||
ui.stage_roles = True
|
||||
ui.poke()
|
||||
|
||||
# Now we are certain that all the rank roles exist and are assignable
|
||||
# Compute season start and season leaderboard
|
||||
lguild = await self.bot.core.lions.fetch_guild(guild.id)
|
||||
season_start = lguild.config.get('season_start').value
|
||||
rank_type = lguild.config.get('rank_type').value
|
||||
stats_model = self._get_stats_model(rank_type)
|
||||
if season_start:
|
||||
leaderboard = await stats_model.leaderboard_since(guild.id, season_start)
|
||||
else:
|
||||
leaderboard = await stats_model.leaderboard_all(guild.id)
|
||||
|
||||
# Compile map of correct ranks
|
||||
# Filtering out members who are untracked or not in server
|
||||
unranked_role_setting = await self.bot.get_cog('StatsCog').settings.UnrankedRoles.get(guild.id)
|
||||
unranked_roleids = set(unranked_role_setting.data)
|
||||
true_member_ranks: dict[int, RankData.VoiceRank | RankData.XPRank | RankData.MsgRank] = {}
|
||||
for userid, stat_total in leaderboard:
|
||||
# Check member exists
|
||||
if member := guild.get_member(userid):
|
||||
# Check member does not have unranked roles
|
||||
if not (member.bot or any(role.id in unranked_roleids for role in member.roles)):
|
||||
# Compute member rank
|
||||
rank = next((rank for rank in reversed(ranks) if rank.required <= stat_total), None)
|
||||
if rank is not None:
|
||||
true_member_ranks[userid] = rank
|
||||
|
||||
# Compile maps of member roles that need removal and member roles that need adding
|
||||
to_remove: list[tuple[discord.Member, list[discord.Role]]] = []
|
||||
to_add: list[tuple[discord.Member, discord.Role]] = []
|
||||
for member in members:
|
||||
if member.bot:
|
||||
continue
|
||||
true_rank = true_member_ranks.get(member.id, None)
|
||||
true_roleid = true_rank.roleid if true_rank is not None else None
|
||||
has_true = (true_roleid is None)
|
||||
invalid = []
|
||||
for role in member.roles:
|
||||
if role.id in roles:
|
||||
if not has_true and role.id == true_roleid:
|
||||
has_true = True
|
||||
else:
|
||||
invalid.append(role)
|
||||
if invalid:
|
||||
to_remove.append((member, invalid))
|
||||
if not has_true:
|
||||
to_add.append((member, roles[true_roleid]))
|
||||
|
||||
ui.stage_compute = True
|
||||
ui.to_remove = len(to_remove)
|
||||
ui.to_add = len(to_add)
|
||||
ui.poke()
|
||||
|
||||
# Perform operations
|
||||
# Starting with removals
|
||||
coros = []
|
||||
bucket = Bucket(4, 5)
|
||||
|
||||
for member, roles in to_remove:
|
||||
remove_coro = member.remove_roles(
|
||||
*roles,
|
||||
reason=t(_p(
|
||||
'rank_refresh|remove_roles|audit',
|
||||
"Removing invalid rank role."
|
||||
))
|
||||
)
|
||||
coros.append(bucket.wrapped(remove_coro))
|
||||
|
||||
index = 0
|
||||
async for task in limit_concurrency(coros, 5):
|
||||
try:
|
||||
await task
|
||||
index += 1
|
||||
ui.poke()
|
||||
except discord.HTTPException:
|
||||
error = t(_p(
|
||||
'rank_refresh|remove_roles|small_error',
|
||||
"*Could not remove ranks from {member}*"
|
||||
)).format(member=to_remove[index][0].mention)
|
||||
self.ui.errors.append(error)
|
||||
if len(self.ui.errors) > 10:
|
||||
await ui.set_error(
|
||||
t(_p(
|
||||
'rank_refresh|remove_roles|error:too_many_issues',
|
||||
"Too many issues occurred while removing ranks! "
|
||||
"Please check my permissions and try again in a few minutes."
|
||||
))
|
||||
)
|
||||
return
|
||||
ui.removed += 1
|
||||
ui.poke()
|
||||
|
||||
coros = []
|
||||
for member, role in to_add:
|
||||
add_coro = member.add_roles(
|
||||
role,
|
||||
reason=t(_p(
|
||||
'rank_refresh|add_roles|audit',
|
||||
"Adding rank role from refresh"
|
||||
))
|
||||
)
|
||||
coros.append(bucket.wrapped(add_coro))
|
||||
|
||||
index = 0
|
||||
async for task in limit_concurrency(coros, 5):
|
||||
try:
|
||||
await task
|
||||
index += 1
|
||||
ui.poke()
|
||||
except discord.HTTPException:
|
||||
error = t(_p(
|
||||
'rank_refresh|add_roles|small_error',
|
||||
"*Could not add {role} to {member}*"
|
||||
)).format(member=to_add[index][0].mention, role=to_add[index][1].mention)
|
||||
self.ui.errors.append(error)
|
||||
if len(self.ui.errors) > 10:
|
||||
await ui.set_error(
|
||||
t(_p(
|
||||
'rank_refresh|add_roles|error:too_many_issues',
|
||||
"Too many issues occurred while adding ranks! "
|
||||
"Please check my permissions and try again in a few minutes."
|
||||
))
|
||||
)
|
||||
return
|
||||
ui.added += 1
|
||||
ui.poke()
|
||||
|
||||
# Save correct member ranks and given roles to data
|
||||
# First clear the member rank data entirely
|
||||
await self.data.MemberRank.table.delete_where(guildid=guild.id)
|
||||
column = self._get_rankid_column(rank_type)
|
||||
tmptable = TemporaryTable(
|
||||
'_gid', '_uid', '_rankid', '_roleid',
|
||||
types=('BIGINT', 'BIGINT', 'BIGINT', 'BIGINT')
|
||||
)
|
||||
tmptable.values = [
|
||||
(guild.id, memberid, rank.rankid, rank.roleid)
|
||||
for memberid, rank in true_member_ranks.items()
|
||||
]
|
||||
if tmptable.values:
|
||||
await self.data.MemberRank.table.update_where(
|
||||
guildid=tmptable['_gid'],
|
||||
userid=tmptable['_uid']
|
||||
).set(
|
||||
**{column: tmptable['_rankid'], 'last_roleid': tmptable['_roleid']}
|
||||
).from_expr(tmptable)
|
||||
|
||||
self.flush_guild_ranks(guild.id)
|
||||
await ui.set_done()
|
||||
await ui.wait()
|
||||
|
||||
# ---------- Commands ----------
|
||||
@cmds.hybrid_command(name=_p('cmd:ranks', "ranks"))
|
||||
async def ranks_cmd(self, ctx: LionContext):
|
||||
|
||||
@@ -2,3 +2,4 @@ from .editor import RankEditor
|
||||
from .preview import RankPreviewUI
|
||||
from .overview import RankOverviewUI
|
||||
from .config import RankConfigUI
|
||||
from .refresh import RankRefreshUI
|
||||
|
||||
@@ -98,7 +98,8 @@ class RankOverviewUI(MessageUI):
|
||||
Refresh the current ranks,
|
||||
ensuring that all members have the correct rank.
|
||||
"""
|
||||
await press.response.send_message("Not Implemented Yet")
|
||||
cog = self.bot.get_cog('RankCog')
|
||||
await cog.interactive_rank_refresh(press, self.guild)
|
||||
|
||||
async def refresh_button_refresh(self):
|
||||
self.refresh_button.label = self.bot.translator.t(_p(
|
||||
|
||||
259
src/modules/ranks/ui/refresh.py
Normal file
259
src/modules/ranks/ui/refresh.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from discord.ui.select import select, Select, SelectOption, RoleSelect
|
||||
from discord.ui.button import button, Button, ButtonStyle
|
||||
|
||||
from meta import conf, LionBot
|
||||
from meta.logger import log_wrap
|
||||
from core.data import RankType
|
||||
from data import ORDER
|
||||
|
||||
from utils.ui import MessageUI
|
||||
from utils.lib import MessageArgs, utc_now
|
||||
from babel.translator import ctx_translator
|
||||
|
||||
from .. import babel, logger
|
||||
from ..data import AnyRankData
|
||||
from ..utils import rank_model_from_type, format_stat_range, stat_data_to_value
|
||||
from .editor import RankEditor
|
||||
from .preview import RankPreviewUI
|
||||
|
||||
_p = babel._p
|
||||
|
||||
|
||||
class RankRefreshUI(MessageUI):
|
||||
def __init__(self, bot: LionBot, guild: discord.Guild, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.bot = bot
|
||||
self.guild = guild
|
||||
|
||||
self.stage_ranks = None
|
||||
self.stage_members = None
|
||||
self.stage_roles = None
|
||||
self.stage_compute = None
|
||||
|
||||
self.to_remove = 0
|
||||
self.to_add = 0
|
||||
self.removed = 0
|
||||
self.added = 0
|
||||
|
||||
self.error: Optional[str] = None
|
||||
self.done = False
|
||||
|
||||
self.errors: list[str] = []
|
||||
|
||||
self._loop_task: Optional[asyncio.Task] = None
|
||||
self._wakeup = asyncio.Event()
|
||||
|
||||
# ----- API -----
|
||||
async def set_error(self, error: str):
|
||||
"""
|
||||
Set the given error, refresh, and stop.
|
||||
"""
|
||||
self.error = error
|
||||
await self.refresh()
|
||||
await self.close()
|
||||
|
||||
async def set_done(self):
|
||||
self.done = True
|
||||
await self.refresh()
|
||||
await self.close()
|
||||
|
||||
def poke(self):
|
||||
self._wakeup.set()
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
await super().run(*args, **kwargs)
|
||||
self._loop_task = asyncio.create_task(self._refresh_loop(), name='refresh ui loop')
|
||||
|
||||
async def cleanup(self):
|
||||
if self._loop_task and not self._loop_task.done():
|
||||
self._loop_task.cancel()
|
||||
await super().cleanup()
|
||||
|
||||
def progress_bar(self, value, minimum, maximum, width=10) -> str:
|
||||
"""
|
||||
Build a text progress bar representing `value` between `minimum` and `maximum`.
|
||||
"""
|
||||
emojis = self.bot.config.emojis
|
||||
|
||||
proportion = (value - minimum) / (maximum - minimum)
|
||||
sections = min(max(int(proportion * width), 0), width)
|
||||
|
||||
bar = []
|
||||
# Starting segment
|
||||
bar.append(str(emojis.progress_left_empty) if sections == 0 else str(emojis.progress_left_full))
|
||||
|
||||
# Full segments up to transition or end
|
||||
if sections >= 2:
|
||||
bar.append(str(emojis.progress_middle_full) * (sections - 2))
|
||||
|
||||
# Transition, if required
|
||||
if 1 < sections < width:
|
||||
bar.append(str(emojis.progress_middle_transition))
|
||||
|
||||
# Empty sections up to end
|
||||
if sections < width:
|
||||
bar.append(str(emojis.progress_middle_empty) * (width - max(sections, 1) - 1))
|
||||
|
||||
# End section
|
||||
bar.append(str(emojis.progress_right_empty) if sections < width else str(emojis.progress_right_full))
|
||||
|
||||
# Join all the sections together and return
|
||||
return ''.join(bar)
|
||||
|
||||
@log_wrap('refresh ui loop')
|
||||
async def _refresh_loop(self):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(1)
|
||||
await self._wakeup.wait()
|
||||
await self.refresh()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
# ----- UI Flow -----
|
||||
async def make_message(self) -> MessageArgs:
|
||||
t = self.bot.translator.t
|
||||
errored = bool(self.error)
|
||||
if errored:
|
||||
waiting_emoji = self.bot.config.emojis.cancel
|
||||
title = t(_p(
|
||||
'ui:refresh_ranks|embed|title:errored',
|
||||
"Could not refresh the server ranks!"
|
||||
))
|
||||
colour = discord.Colour.brand_red()
|
||||
else:
|
||||
waiting_emoji = self.bot.config.emojis.loading
|
||||
if self.done:
|
||||
title = t(_p(
|
||||
'ui:refresh_ranks|embed|title:done',
|
||||
"Rank refresh complete!"
|
||||
))
|
||||
colour = discord.Colour.brand_green()
|
||||
else:
|
||||
title = t(_p(
|
||||
'ui:refresh_ranks|embed|title:working',
|
||||
"Refreshing your server ranks, please wait."
|
||||
))
|
||||
colour = discord.Colour.orange()
|
||||
|
||||
embed = discord.Embed(
|
||||
colour=colour,
|
||||
title=title,
|
||||
timestamp=utc_now()
|
||||
)
|
||||
|
||||
lines = []
|
||||
stop_here = False
|
||||
|
||||
if not stop_here:
|
||||
stage = self.stage_ranks
|
||||
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
|
||||
text = t(_p(
|
||||
'ui:refresh_ranks|embed|line:ranks',
|
||||
"**Loading server ranks:** {emoji}"
|
||||
)).format(emoji=emoji)
|
||||
lines.append(text)
|
||||
stop_here = not bool(stage)
|
||||
|
||||
if not stop_here:
|
||||
stage = self.stage_members
|
||||
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
|
||||
text = t(_p(
|
||||
'ui:refresh_ranks|embed|line:members',
|
||||
"**Loading server members:** {emoji}"
|
||||
)).format(emoji=emoji)
|
||||
lines.append(text)
|
||||
stop_here = not bool(stage)
|
||||
|
||||
if not stop_here:
|
||||
stage = self.stage_roles
|
||||
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
|
||||
text = t(_p(
|
||||
'ui:refresh_ranks|embed|line:roles',
|
||||
"**Loading rank roles:** {emoji}"
|
||||
)).format(emoji=emoji)
|
||||
lines.append(text)
|
||||
stop_here = not bool(stage)
|
||||
|
||||
if not stop_here:
|
||||
stage = self.stage_compute
|
||||
emoji = self.bot.config.emojis.tick if stage else waiting_emoji
|
||||
text = t(_p(
|
||||
'ui:refresh_ranks|embed|line:compute',
|
||||
"**Computing correct ranks:** {emoji}"
|
||||
)).format(emoji=emoji)
|
||||
lines.append(text)
|
||||
stop_here = not bool(stage)
|
||||
|
||||
if not stop_here:
|
||||
lines.append("")
|
||||
if self.to_remove > self.removed and not errored:
|
||||
# Still have members to remove, show loading bar
|
||||
name = t(_p(
|
||||
'ui:refresh_ranks|embed|field:remove|name',
|
||||
"Removing invalid rank roles from members"
|
||||
))
|
||||
value = t(_p(
|
||||
'ui:refresh_ranks|embed|field:remove|value',
|
||||
"0 {progress} {total}"
|
||||
)).format(
|
||||
progress=self.progress_bar(self.removed, 0, self.to_remove),
|
||||
total=self.to_remove,
|
||||
)
|
||||
embed.add_field(name=name, value=value, inline=False)
|
||||
else:
|
||||
emoji = self.bot.config.emojis.tick
|
||||
text = t(_p(
|
||||
'ui:refresh_ranks|embed|line:remove',
|
||||
"**Removed invalid ranks:** {done}/{target}"
|
||||
)).format(done=self.removed, target=self.to_remove)
|
||||
lines.append(text)
|
||||
|
||||
if self.to_add > self.added and not errored:
|
||||
# Still have members to add, show loading bar
|
||||
name = t(_p(
|
||||
'ui:refresh_ranks|embed|field:add|name',
|
||||
"Giving members their rank roles"
|
||||
))
|
||||
value = t(_p(
|
||||
'ui:refresh_ranks|embed|field:add|value',
|
||||
"0 {progress} {total}"
|
||||
)).format(
|
||||
progress=self.progress_bar(self.added, 0, self.to_add),
|
||||
total=self.to_add,
|
||||
)
|
||||
embed.add_field(name=name, value=value, inline=False)
|
||||
else:
|
||||
emoji = self.bot.config.emojis.tick
|
||||
text = t(_p(
|
||||
'ui:refresh_ranks|embed|line:add',
|
||||
"**Updated member ranks:** {done}/{target}"
|
||||
)).format(done=self.added, target=self.to_add)
|
||||
lines.append(text)
|
||||
|
||||
embed.description = '\n'.join(lines)
|
||||
if self.errors:
|
||||
name = (
|
||||
'ui:refresh_ranks|embed|field:errors|title',
|
||||
"Issues"
|
||||
)
|
||||
value = '\n'.join(self.errors)
|
||||
embed.add_field(name=name, value=value, inline=False)
|
||||
if self.error:
|
||||
name = (
|
||||
'ui:refresh_ranks|embed|field:critical|title',
|
||||
"Critical Error! Cannot complete refresh"
|
||||
)
|
||||
embed.add_field(name=name, value=self.error, inline=False)
|
||||
|
||||
return MessageArgs(embed=embed)
|
||||
|
||||
async def refresh_layout(self):
|
||||
pass
|
||||
|
||||
async def reload(self):
|
||||
pass
|
||||
@@ -288,5 +288,53 @@ class TextTrackerData(Registry):
|
||||
tuple(chain((userid, guildid), points))
|
||||
)
|
||||
return [r['messages'] or 0 for r in await cursor.fetchall()]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='msgs_leaderboard_all')
|
||||
async def leaderboard_since(cls, guildid: int, since):
|
||||
"""
|
||||
Return the message count totals for the given guild since the given time.
|
||||
"""
|
||||
query = sql.SQL(
|
||||
"""
|
||||
SELECT userid, sum(messages) as user_total
|
||||
FROM text_sessions
|
||||
WHERE guildid = %s AND start_time >= %s
|
||||
GROUP BY userid
|
||||
ORDER BY
|
||||
"""
|
||||
)
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(query, (guildid, since))
|
||||
leaderboard = [
|
||||
(row['userid'], int(row['user_total']))
|
||||
for row in await cursor.fetchall()
|
||||
]
|
||||
return leaderboard
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='msgs_leaderboard_all')
|
||||
async def leaderboard_all(cls, guildid: int):
|
||||
"""
|
||||
Return the all-time message count totals for the given guild.
|
||||
"""
|
||||
query = sql.SQL(
|
||||
"""
|
||||
SELECT userid, sum(messages) as user_total
|
||||
FROM text_sessions
|
||||
WHERE guildid = %s
|
||||
GROUP BY userid
|
||||
ORDER BY
|
||||
"""
|
||||
)
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(query, (guildid,))
|
||||
leaderboard = [
|
||||
(row['userid'], int(row['user_total']))
|
||||
for row in await cursor.fetchall()
|
||||
]
|
||||
return leaderboard
|
||||
|
||||
untracked_channels = Table('untracked_text_channels')
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
|
||||
from meta.errors import SafeCancellation
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
|
||||
class BucketFull(Exception):
|
||||
"""
|
||||
@@ -129,3 +133,34 @@ class RateLimit:
|
||||
return await func(ctx, *args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
async def limit_concurrency(aws, limit):
|
||||
"""
|
||||
Run provided awaitables concurrently,
|
||||
ensuring that no more than `limit` are running at once.
|
||||
"""
|
||||
aws = iter(aws)
|
||||
aws_ended = False
|
||||
pending = set()
|
||||
count = 0
|
||||
logger.debug("Starting limited concurrency executor")
|
||||
|
||||
while pending or not aws_ended:
|
||||
while len(pending) < limit and not aws_ended:
|
||||
aw = next(aws, None)
|
||||
if aw is None:
|
||||
aws_ended = True
|
||||
else:
|
||||
pending.add(asyncio.create_task(aw))
|
||||
count += 1
|
||||
|
||||
if not pending:
|
||||
break
|
||||
|
||||
done, pending = await asyncio.wait(
|
||||
pending, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
while done:
|
||||
yield done.pop()
|
||||
logger.debug(f"Completed {count} tasks")
|
||||
|
||||
Reference in New Issue
Block a user