From 0dd5213f13a8ca5793ae9ff5bbbf498e5e94bee4 Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 22 Dec 2021 19:07:28 +0200 Subject: [PATCH] sharding (accountability): Adapt for sharding. Filter initially loaded accountability guilds. Filter timeslots loaded in `open_next`. Reload members and overwrites on slot start. --- bot/modules/accountability/TimeSlot.py | 29 +++++++++++++++++++++++- bot/modules/accountability/tracker.py | 31 ++++++++++++++++++++++---- 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/bot/modules/accountability/TimeSlot.py b/bot/modules/accountability/TimeSlot.py index 9464e4e5..81cbe38d 100644 --- a/bot/modules/accountability/TimeSlot.py +++ b/bot/modules/accountability/TimeSlot.py @@ -90,7 +90,6 @@ class TimeSlot: @property def open_embed(self): - # TODO Consider adding hint to footer timestamp = int(self.start_time.timestamp()) embed = discord.Embed( @@ -247,6 +246,34 @@ class TimeSlot: return self + async def _reload_members(self, memberids=None): + """ + Reload the timeslot members from the provided list, or data. + Also updates the channel overwrites if required. + To be used before the session has started. + """ + if self.data: + if memberids is None: + member_rows = accountability_members.fetch_rows_where(slotid=self.data.slotid) + memberids = [row.userid for row in member_rows] + + self.members = members = { + memberid: SlotMember(self.data.slotid, memberid, self.guild) + for memberid in memberids + } + + if self.channel: + # Check and potentially update overwrites + current_overwrites = self.channel.overwrites + overwrites = { + mem.member: self._member_overwrite + for mem in members.values() + if mem.member + } + overwrites[self.guild.default_role] = self._everyone_overwrite + if current_overwrites != overwrites: + await self.channel.edit(overwrites=overwrites) + def _refresh(self): """ Refresh the stored data row and reload. diff --git a/bot/modules/accountability/tracker.py b/bot/modules/accountability/tracker.py index 24e1dc94..faa82867 100644 --- a/bot/modules/accountability/tracker.py +++ b/bot/modules/accountability/tracker.py @@ -10,7 +10,7 @@ from discord.utils import sleep_until from meta import client from utils.interactive import discord_shield from data import NULL, NOTNULL, tables -from data.conditions import LEQ +from data.conditions import LEQ, THIS_SHARD from settings import GuildSettings from .TimeSlot import TimeSlot @@ -67,7 +67,8 @@ async def open_next(start_time): """ # Pre-fetch the new slot data, also populating the table caches room_data = accountability_rooms.fetch_rows_where( - start_at=start_time + start_at=start_time, + guildid=THIS_SHARD ) guild_rows = {row.guildid: row for row in room_data} member_data = accountability_members.fetch_rows_where( @@ -193,11 +194,30 @@ async def turnover(): # TODO: (FUTURE) with high volume, we might want to start the sessions before moving the members. # We could break up the session starting? - # Move members of the next session over to the session channel + # ---------- Start next session ---------- current_slots = [ aguild.current_slot for aguild in AccountabilityGuild.cache.values() if aguild.current_slot is not None ] + slotmap = {slot.data.slotid: slot for slot in current_slots if slot.data} + + # Reload the slot members in case they cancelled from another shard + member_data = accountability_members.fetch_rows_where( + slotid=list(slotmap.keys()) + ) if slotmap else [] + slot_memberids = {slotid: [] for slotid in slotmap} + for row in member_data: + slot_memberids[row.slotid].append(row.userid) + reload_tasks = ( + slot._reload_members(memberids=slot_memberids[slotid]) + for slotid, slot in slotmap.items() + ) + await asyncio.gather( + *reload_tasks, + return_exceptions=True + ) + + # Move members of the next session over to the session channel movement_tasks = ( mem.member.edit( voice_channel=slot.channel, @@ -335,6 +355,7 @@ async def _accountability_system_resume(): open_room_data = accountability_rooms.fetch_rows_where( closed_at=NULL, start_at=LEQ(now), + guildid=THIS_SHARD, _extra="ORDER BY start_at ASC" ) @@ -450,8 +471,10 @@ async def launch_accountability_system(client): """ # Load the AccountabilityGuild cache guilds = tables.guild_config.fetch_rows_where( - accountability_category=NOTNULL + accountability_category=NOTNULL, + guildid=THIS_SHARD ) + # Further filter out any guilds that we aren't in [AccountabilityGuild(guild.guildid) for guild in guilds if client.get_guild(guild.guildid)] await _accountability_system_resume() asyncio.create_task(_accountability_loop())