diff --git a/bot/LionModule.py b/bot/LionModule.py index 6e89cf5e..9c2a4671 100644 --- a/bot/LionModule.py +++ b/bot/LionModule.py @@ -82,7 +82,7 @@ class LionModule(Module): raise SafeCancellation(details="Module '{}' is not ready.".format(self.name)) # Check global user blacklist - if ctx.author.id in ctx.client.objects['blacklisted_users']: + if ctx.author.id in ctx.client.user_blacklist(): raise SafeCancellation(details='User is blacklisted.') if ctx.guild: @@ -91,7 +91,7 @@ class LionModule(Module): raise SafeCancellation(details='Command channel is no longer reachable.') # Check global guild blacklist - if ctx.guild.id in ctx.client.objects['blacklisted_guilds']: + if ctx.guild.id in ctx.client.guild_blacklist(): raise SafeCancellation(details='Guild is blacklisted.') # Check guild's own member blacklist diff --git a/bot/core/blacklists.py b/bot/core/blacklists.py index 942bd012..1ca5bd9c 100644 --- a/bot/core/blacklists.py +++ b/bot/core/blacklists.py @@ -1,9 +1,8 @@ """ Guild, user, and member blacklists. - -NOTE: The pre-loading methods are not shard-optimised. """ from collections import defaultdict +import cachetools.func from data import tables from meta import client @@ -11,32 +10,22 @@ from meta import client from .module import module -@module.init_task -def load_guild_blacklist(client): +@cachetools.func.ttl_cache(ttl=300) +def guild_blacklist(): """ - Load the blacklisted guilds. + Get the guild blacklist """ rows = tables.global_guild_blacklist.select_where() - client.objects['blacklisted_guilds'] = set(row['guildid'] for row in rows) - if rows: - client.log( - "Loaded {} blacklisted guilds.".format(len(rows)), - context="GUILD_BLACKLIST" - ) + return set(row['guildid'] for row in rows) -@module.init_task -def load_user_blacklist(client): +@cachetools.func.ttl_cache(ttl=300) +def user_blacklist(): """ - Load the blacklisted users. + Get the global user blacklist. """ rows = tables.global_user_blacklist.select_where() - client.objects['blacklisted_users'] = set(row['userid'] for row in rows) - if rows: - client.log( - "Loaded {} globally blacklisted users.".format(len(rows)), - context="USER_BLACKLIST" - ) + return set(row['userid'] for row in rows) @module.init_task @@ -62,18 +51,20 @@ def load_ignored_members(client): ) +@module.init_task +def attach_client_blacklists(client): + client.guild_blacklist = guild_blacklist + client.user_blacklist = user_blacklist + + @module.launch_task async def leave_blacklisted_guilds(client): """ Launch task to leave any blacklisted guilds we are in. - Assumes that the blacklisted guild list has been initialised. """ - # Cache to avoic repeated lookups - blacklisted = client.objects['blacklisted_guilds'] - to_leave = [ guild for guild in client.guilds - if guild.id in blacklisted + if guild.id in guild_blacklist() ] for guild in to_leave: @@ -92,7 +83,8 @@ async def check_guild_blacklist(client, guild): Guild join event handler to check whether the guild is blacklisted. If so, leaves the guild. """ - if guild.id in client.objects['blacklisted_guilds']: + # First refresh the blacklist cache + if guild.id in guild_blacklist(): await guild.leave() client.log( "Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id), diff --git a/bot/data/__init__.py b/bot/data/__init__.py index f048ce37..2deecc48 100644 --- a/bot/data/__init__.py +++ b/bot/data/__init__.py @@ -1,5 +1,5 @@ +from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa from .connection import conn # noqa from .formatters import UpdateValue, UpdateValueAdd # noqa from .interfaces import Table, RowTable, Row, tables # noqa from .queries import insert, insert_many, select_where, update_where, upsert, delete_where # noqa -from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa diff --git a/bot/data/conditions.py b/bot/data/conditions.py index 4687a929..a314616e 100644 --- a/bot/data/conditions.py +++ b/bot/data/conditions.py @@ -1,5 +1,7 @@ from .connection import _replace_char +from meta import sharding + class Condition: """ @@ -70,5 +72,21 @@ class Constant(Condition): conditions.append("{} {}".format(key, self.value)) +class SHARDID(Condition): + __slots__ = ('shardid', 'shard_count') + + def __init__(self, shardid, shard_count): + self.shardid = shardid + self.shard_count = shard_count + + def apply(self, key, values, conditions): + if self.shard_count > 1: + conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char)) + values.append(self.shardid) + + +THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count) + + NULL = Constant('IS NULL') NOTNULL = Constant('IS NOT NULL') diff --git a/bot/main.py b/bot/main.py index ac818e36..066bf86e 100644 --- a/bot/main.py +++ b/bot/main.py @@ -1,4 +1,4 @@ -from meta import client, conf, log +from meta import client, conf, log, sharding from data import tables @@ -7,7 +7,12 @@ import core # noqa import modules # noqa # Load and attach app specific data -client.appdata = core.data.meta.fetch_or_create(conf.bot['data_appid']) +if sharding.sharded: + appname = f"{conf.bot['data_appid']}_{sharding.shard_count}_{sharding.shard_number}" +else: + appname = conf.bot['data_appid'] +client.appdata = core.data.meta.fetch_or_create(appname) + client.data = tables # Initialise all modules diff --git a/bot/meta/__init__.py b/bot/meta/__init__.py index dd852d4f..eab9c7b8 100644 --- a/bot/meta/__init__.py +++ b/bot/meta/__init__.py @@ -1,3 +1,5 @@ +from .logger import log, logger from .client import client from .config import conf -from .logger import log, logger +from .args import args +from . import sharding diff --git a/bot/meta/args.py b/bot/meta/args.py new file mode 100644 index 00000000..c2dd70d6 --- /dev/null +++ b/bot/meta/args.py @@ -0,0 +1,19 @@ +import argparse + +from constants import CONFIG_FILE + +# ------------------------------ +# Parsed commandline arguments +# ------------------------------ +parser = argparse.ArgumentParser() +parser.add_argument('--conf', + dest='config', + default=CONFIG_FILE, + help="Path to configuration file.") +parser.add_argument('--shard', + dest='shard', + default=None, + type=int, + help="Shard number to run, if applicable.") + +args = parser.parse_args() diff --git a/bot/meta/client.py b/bot/meta/client.py index 5310171d..50414aa8 100644 --- a/bot/meta/client.py +++ b/bot/meta/client.py @@ -1,16 +1,19 @@ from discord import Intents from cmdClient.cmdClient import cmdClient -from .config import Conf +from .config import conf +from .sharding import shard_number, shard_count -from constants import CONFIG_FILE - -# Initialise config -conf = Conf(CONFIG_FILE) # Initialise client owners = [int(owner) for owner in conf.bot.getlist('owners')] intents = Intents.all() intents.presences = False -client = cmdClient(prefix=conf.bot['prefix'], owners=owners, intents=intents) +client = cmdClient( + prefix=conf.bot['prefix'], + owners=owners, + intents=intents, + shard_id=shard_number, + shard_count=shard_count +) client.conf = conf diff --git a/bot/meta/config.py b/bot/meta/config.py index a94d2b1a..ca779924 100644 --- a/bot/meta/config.py +++ b/bot/meta/config.py @@ -1,9 +1,6 @@ import configparser as cfgp - -conf = None # type: Conf - -CONF_FILE = "bot/bot.conf" +from .args import args class Conf: @@ -57,3 +54,6 @@ class Conf: def write(self): with open(self.configfile, 'w') as conffile: self.config.write(conffile) + + +conf = Conf(args.config) diff --git a/bot/meta/logger.py b/bot/meta/logger.py index 858b1292..a95500e4 100644 --- a/bot/meta/logger.py +++ b/bot/meta/logger.py @@ -9,11 +9,18 @@ from utils.lib import mail, split_text from .client import client from .config import conf +from . import sharding # Setup the logger logger = logging.getLogger() -log_fmt = logging.Formatter(fmt='[{asctime}][{levelname:^8}] {message}', datefmt='%d/%m | %H:%M:%S', style='{') +log_fmt = logging.Formatter( + fmt=('[{asctime}][{levelname:^8}]' + + '[SHARD {}]'.format(sharding.shard_number) + + ' {message}'), + datefmt='%d/%m | %H:%M:%S', + style='{' +) # term_handler = logging.StreamHandler(sys.stdout) # term_handler.setFormatter(log_fmt) # logger.addHandler(term_handler) @@ -77,7 +84,11 @@ async def live_log(message, context, level): log_chid = conf.bot.getint('log_channel') # Generate the log messages - header = "[{}][{}]".format(logging.getLevelName(level), str(context)) + if sharding.sharded: + header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]" + else: + header = f"[{logging.getLevelName(level)}][{context}]" + if len(message) > 1900: blocks = split_text(message, blocksize=1900, code=False) else: diff --git a/bot/meta/sharding.py b/bot/meta/sharding.py new file mode 100644 index 00000000..ffe86a89 --- /dev/null +++ b/bot/meta/sharding.py @@ -0,0 +1,9 @@ +from .args import args +from .config import conf + + +shard_number = args.shard or 0 + +shard_count = conf.bot.getint('shard_count', 1) + +sharded = (shard_count > 0) diff --git a/bot/modules/accountability/TimeSlot.py b/bot/modules/accountability/TimeSlot.py index 9464e4e5..81cbe38d 100644 --- a/bot/modules/accountability/TimeSlot.py +++ b/bot/modules/accountability/TimeSlot.py @@ -90,7 +90,6 @@ class TimeSlot: @property def open_embed(self): - # TODO Consider adding hint to footer timestamp = int(self.start_time.timestamp()) embed = discord.Embed( @@ -247,6 +246,34 @@ class TimeSlot: return self + async def _reload_members(self, memberids=None): + """ + Reload the timeslot members from the provided list, or data. + Also updates the channel overwrites if required. + To be used before the session has started. + """ + if self.data: + if memberids is None: + member_rows = accountability_members.fetch_rows_where(slotid=self.data.slotid) + memberids = [row.userid for row in member_rows] + + self.members = members = { + memberid: SlotMember(self.data.slotid, memberid, self.guild) + for memberid in memberids + } + + if self.channel: + # Check and potentially update overwrites + current_overwrites = self.channel.overwrites + overwrites = { + mem.member: self._member_overwrite + for mem in members.values() + if mem.member + } + overwrites[self.guild.default_role] = self._everyone_overwrite + if current_overwrites != overwrites: + await self.channel.edit(overwrites=overwrites) + def _refresh(self): """ Refresh the stored data row and reload. diff --git a/bot/modules/accountability/tracker.py b/bot/modules/accountability/tracker.py index 24e1dc94..faa82867 100644 --- a/bot/modules/accountability/tracker.py +++ b/bot/modules/accountability/tracker.py @@ -10,7 +10,7 @@ from discord.utils import sleep_until from meta import client from utils.interactive import discord_shield from data import NULL, NOTNULL, tables -from data.conditions import LEQ +from data.conditions import LEQ, THIS_SHARD from settings import GuildSettings from .TimeSlot import TimeSlot @@ -67,7 +67,8 @@ async def open_next(start_time): """ # Pre-fetch the new slot data, also populating the table caches room_data = accountability_rooms.fetch_rows_where( - start_at=start_time + start_at=start_time, + guildid=THIS_SHARD ) guild_rows = {row.guildid: row for row in room_data} member_data = accountability_members.fetch_rows_where( @@ -193,11 +194,30 @@ async def turnover(): # TODO: (FUTURE) with high volume, we might want to start the sessions before moving the members. # We could break up the session starting? - # Move members of the next session over to the session channel + # ---------- Start next session ---------- current_slots = [ aguild.current_slot for aguild in AccountabilityGuild.cache.values() if aguild.current_slot is not None ] + slotmap = {slot.data.slotid: slot for slot in current_slots if slot.data} + + # Reload the slot members in case they cancelled from another shard + member_data = accountability_members.fetch_rows_where( + slotid=list(slotmap.keys()) + ) if slotmap else [] + slot_memberids = {slotid: [] for slotid in slotmap} + for row in member_data: + slot_memberids[row.slotid].append(row.userid) + reload_tasks = ( + slot._reload_members(memberids=slot_memberids[slotid]) + for slotid, slot in slotmap.items() + ) + await asyncio.gather( + *reload_tasks, + return_exceptions=True + ) + + # Move members of the next session over to the session channel movement_tasks = ( mem.member.edit( voice_channel=slot.channel, @@ -335,6 +355,7 @@ async def _accountability_system_resume(): open_room_data = accountability_rooms.fetch_rows_where( closed_at=NULL, start_at=LEQ(now), + guildid=THIS_SHARD, _extra="ORDER BY start_at ASC" ) @@ -450,8 +471,10 @@ async def launch_accountability_system(client): """ # Load the AccountabilityGuild cache guilds = tables.guild_config.fetch_rows_where( - accountability_category=NOTNULL + accountability_category=NOTNULL, + guildid=THIS_SHARD ) + # Further filter out any guilds that we aren't in [AccountabilityGuild(guild.guildid) for guild in guilds if client.get_guild(guild.guildid)] await _accountability_system_resume() asyncio.create_task(_accountability_loop()) diff --git a/bot/modules/economy/cointop_cmd.py b/bot/modules/economy/cointop_cmd.py index 9d1b9b2d..81bdbad9 100644 --- a/bot/modules/economy/cointop_cmd.py +++ b/bot/modules/economy/cointop_cmd.py @@ -43,7 +43,7 @@ async def cmd_topcoin(ctx): # Fetch the leaderboard exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members) - exclude.update(ctx.client.objects['blacklisted_users']) + exclude.update(ctx.client.user_blacklist()) exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id]) args = { diff --git a/bot/modules/guild_admin/reaction_roles/tracker.py b/bot/modules/guild_admin/reaction_roles/tracker.py index f18e3c34..17a64960 100644 --- a/bot/modules/guild_admin/reaction_roles/tracker.py +++ b/bot/modules/guild_admin/reaction_roles/tracker.py @@ -12,6 +12,7 @@ from discord import PartialEmoji from meta import client from core import Lion from data import Row +from data.conditions import THIS_SHARD from utils.lib import utc_now from settings import GuildSettings @@ -584,5 +585,5 @@ def load_reaction_roles(client): """ Load the ReactionRoleMessages. """ - rows = reaction_role_messages.fetch_rows_where() + rows = reaction_role_messages.fetch_rows_where(guildid=THIS_SHARD) ReactionRoleMessage._messages = {row.messageid: ReactionRoleMessage(row.messageid) for row in rows} diff --git a/bot/modules/moderation/tickets/Ticket.py b/bot/modules/moderation/tickets/Ticket.py index 4d7ec5ec..afea1eef 100644 --- a/bot/modules/moderation/tickets/Ticket.py +++ b/bot/modules/moderation/tickets/Ticket.py @@ -6,6 +6,7 @@ import datetime import discord from meta import client +from data.conditions import THIS_SHARD from settings import GuildSettings from utils.lib import FieldEnum, strfdelta, utc_now @@ -283,7 +284,8 @@ class Ticket: # Get all expiring tickets expiring_rows = data.tickets.select_where( - ticket_state=TicketState.EXPIRING + ticket_state=TicketState.EXPIRING, + guildid=THIS_SHARD ) # Create new expiry tasks diff --git a/bot/modules/reminders/commands.py b/bot/modules/reminders/commands.py index 0bb98d60..c8637a04 100644 --- a/bot/modules/reminders/commands.py +++ b/bot/modules/reminders/commands.py @@ -3,6 +3,7 @@ import asyncio import datetime import discord +from meta import sharding from utils.lib import parse_dur, parse_ranges, multiselect_regex from .module import module @@ -55,7 +56,7 @@ async def cmd_remindme(ctx, flags): if not rows: return await ctx.reply("You have no reminders to remove!") - live = Reminder.fetch(*(row.reminderid for row in rows)) + live = [Reminder(row.reminderid) for row in rows] if not ctx.args: lines = [] @@ -209,7 +210,8 @@ async def cmd_remindme(ctx, flags): ) # Schedule reminder - reminder.schedule() + if sharding.shard_number == 0: + reminder.schedule() # Ack embed = discord.Embed( @@ -231,7 +233,7 @@ async def cmd_remindme(ctx, flags): if not rows: return await ctx.reply("You have no reminders!") - live = Reminder.fetch(*(row.reminderid for row in rows)) + live = [Reminder(row.reminderid) for row in rows] lines = [] num_field = len(str(len(live) - 1)) diff --git a/bot/modules/reminders/reminder.py b/bot/modules/reminders/reminder.py index d3e4f764..67956a1d 100644 --- a/bot/modules/reminders/reminder.py +++ b/bot/modules/reminders/reminder.py @@ -1,8 +1,9 @@ import asyncio import datetime +import logging import discord -from meta import client +from meta import client, sharding from utils.lib import strfdur from .data import reminders @@ -46,7 +47,10 @@ class Reminder: cls._live_reminders[reminderid].cancel() # Remove from data - reminders.delete_where(reminderid=reminderids) + if reminderids: + return reminders.delete_where(reminderid=reminderids) + else: + return [] @property def data(self): @@ -134,10 +138,16 @@ class Reminder: """ Execute the reminder. """ - if self.data.userid in client.objects['blacklisted_users']: + if not self.data: + # Reminder deleted elsewhere + return + + if self.data.userid in client.user_blacklist(): self.delete(self.reminderid) return + userid = self.data.userid + # Build the message embed embed = discord.Embed( title="You asked me to remind you!", @@ -155,8 +165,26 @@ class Reminder: ) ) + # Update the reminder data, and reschedule if required + if self.data.interval: + next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval) + rows = reminders.update_where( + {'remind_at': next_time}, + reminderid=self.reminderid + ) + self.schedule() + else: + rows = self.delete(self.reminderid) + if not rows: + # Reminder deleted elsewhere + return + # Send the message, if possible - user = self.user + if not (user := client.get_user(userid)): + try: + user = await client.fetch_user(userid) + except discord.HTTPException: + pass if user: try: await user.send(embed=embed) @@ -164,21 +192,38 @@ class Reminder: # Nothing we can really do here. Maybe tell the user about their reminder next time? pass - # Update the reminder data, and reschedule if required - if self.data.interval: - next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval) - reminders.update_where({'remind_at': next_time}, reminderid=self.reminderid) - self.schedule() - else: - self.delete(self.reminderid) + +async def reminder_poll(client): + """ + One client/shard must continually poll for new or deleted reminders. + """ + # TODO: Clean this up with database signals or IPC + while True: + await asyncio.sleep(60) + + client.log( + "Running new reminder poll.", + context="REMINDERS", + level=logging.DEBUG + ) + + rids = {row.reminderid for row in reminders.fetch_rows_where()} + + to_delete = (rid for rid in Reminder._live_reminders if rid not in rids) + Reminder.delete(*to_delete) + + [Reminder(rid).schedule() for rid in rids if rid not in Reminder._live_reminders] @module.launch_task async def schedule_reminders(client): - rows = reminders.fetch_rows_where() - for row in rows: - Reminder(row.reminderid).schedule() - client.log( - "Scheduled {} reminders.".format(len(rows)), - context="LAUNCH_REMINDERS" - ) + if sharding.shard_number == 0: + rows = reminders.fetch_rows_where() + for row in rows: + Reminder(row.reminderid).schedule() + client.log( + "Scheduled {} reminders.".format(len(rows)), + context="LAUNCH_REMINDERS" + ) + if sharding.sharded: + asyncio.create_task(reminder_poll(client)) diff --git a/bot/modules/renting/rooms.py b/bot/modules/renting/rooms.py index 3e1d19c4..a8c29876 100644 --- a/bot/modules/renting/rooms.py +++ b/bot/modules/renting/rooms.py @@ -5,6 +5,7 @@ import datetime from cmdClient.lib import SafeCancellation from meta import client +from data.conditions import THIS_SHARD from settings import GuildSettings from .data import rented, rented_members @@ -276,7 +277,7 @@ class Room: @module.launch_task async def load_rented_rooms(client): - rows = rented.fetch_rows_where() + rows = rented.fetch_rows_where(guildid=THIS_SHARD) for row in rows: Room(row.channelid).schedule() client.log( diff --git a/bot/modules/study/badges/badge_tracker.py b/bot/modules/study/badges/badge_tracker.py index 2c0d33fb..721f3962 100644 --- a/bot/modules/study/badges/badge_tracker.py +++ b/bot/modules/study/badges/badge_tracker.py @@ -6,8 +6,8 @@ import contextlib import discord -from meta import client -from data.conditions import GEQ +from meta import client, sharding +from data.conditions import GEQ, THIS_SHARD from core.data import lions from utils.lib import strfdur from settings import GuildSettings @@ -54,12 +54,16 @@ async def update_study_badges(full=False): # Retrieve member rows with out of date study badges if not full and client.appdata.last_study_badge_scan is not None: + # TODO: _extra here is a hack to cover for inflexible conditionals update_rows = new_study_badges.select_where( + guildid=THIS_SHARD, _timestamp=GEQ(client.appdata.last_study_badge_scan or 0), - _extra="OR session_start IS NOT NULL" + _extra="OR session_start IS NOT NULL AND (guildid >> 22) %% {} = {}".format( + sharding.shard_count, sharding.shard_number + ) ) else: - update_rows = new_study_badges.select_where() + update_rows = new_study_badges.select_where(guildid=THIS_SHARD) if not update_rows: client.appdata.last_study_badge_scan = datetime.datetime.utcnow() diff --git a/bot/modules/study/stats_cmd.py b/bot/modules/study/stats_cmd.py index 29e412e6..88bc8be5 100644 --- a/bot/modules/study/stats_cmd.py +++ b/bot/modules/study/stats_cmd.py @@ -59,7 +59,7 @@ async def cmd_stats(ctx): # Leaderboard ranks exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members) - exclude.update(ctx.client.objects['blacklisted_users']) + exclude.update(ctx.client.user_blacklist()) exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id]) if target.id in exclude: time_rank = None diff --git a/bot/modules/study/top_cmd.py b/bot/modules/study/top_cmd.py index cb4008f7..79564c1f 100644 --- a/bot/modules/study/top_cmd.py +++ b/bot/modules/study/top_cmd.py @@ -40,7 +40,7 @@ async def cmd_top(ctx): # Fetch the leaderboard exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members) - exclude.update(ctx.client.objects['blacklisted_users']) + exclude.update(ctx.client.user_blacklist()) exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id]) args = { diff --git a/bot/modules/study/tracking/session_tracker.py b/bot/modules/study/tracking/session_tracker.py index 96262f04..8158f96a 100644 --- a/bot/modules/study/tracking/session_tracker.py +++ b/bot/modules/study/tracking/session_tracker.py @@ -7,6 +7,7 @@ from collections import defaultdict from utils.lib import utc_now from data import tables +from data.conditions import THIS_SHARD from core import Lion from meta import client @@ -298,7 +299,7 @@ async def session_voice_tracker(client, member, before, after): pending.cancel() if after.channel: - blacklist = client.objects['blacklisted_users'] + blacklist = client.user_blacklist() guild_blacklist = client.objects['ignored_members'][guild.id] untracked = untracked_channels.get(guild.id).data start_session = ( @@ -398,7 +399,7 @@ async def _init_session_tracker(client): ended = 0 # Grab all ongoing sessions from data - rows = current_sessions.fetch_rows_where() + rows = current_sessions.fetch_rows_where(guildid=THIS_SHARD) # Iterate through, resume or end as needed for row in rows: diff --git a/bot/modules/study/tracking/time_tracker.py b/bot/modules/study/tracking/time_tracker.py index 1cb35fa0..46f88ec7 100644 --- a/bot/modules/study/tracking/time_tracker.py +++ b/bot/modules/study/tracking/time_tracker.py @@ -47,7 +47,7 @@ def _scan(guild): members = itertools.chain(*channel_members) # TODO filter out blacklisted users - blacklist = client.objects['blacklisted_users'] + blacklist = client.user_blacklist() guild_blacklist = client.objects['ignored_members'][guild.id] for member in members: diff --git a/bot/modules/sysadmin/blacklist.py b/bot/modules/sysadmin/blacklist.py index 12a2ed9b..90202407 100644 --- a/bot/modules/sysadmin/blacklist.py +++ b/bot/modules/sysadmin/blacklist.py @@ -7,6 +7,8 @@ import discord from cmdClient.checks import is_owner from cmdClient.lib import ResponseTimedOut +from meta.sharding import sharded + from .module import module @@ -26,14 +28,14 @@ async def cmd_guildblacklist(ctx, flags): Description: View, add, or remove guilds from the blacklist. """ - blacklist = ctx.client.objects['blacklisted_guilds'] + blacklist = ctx.client.guild_blacklist() if ctx.args: # guildid parsing items = [item.strip() for item in ctx.args.split(',')] if any(not item.isdigit() for item in items): return await ctx.error_reply( - "Please provide guilds as comma seprated guild ids." + "Please provide guilds as comma separated guild ids." ) guildids = set(int(item) for item in items) @@ -80,9 +82,18 @@ async def cmd_guildblacklist(ctx, flags): insert_keys=('guildid', 'ownerid', 'reason') ) - # Check if we are in any of these guilds - to_leave = (ctx.client.get_guild(guildid) for guildid in to_add) - to_leave = [guild for guild in to_leave if guild is not None] + # Leave freshly blacklisted guilds, accounting for shards + to_leave = [] + for guildid in to_add: + guild = ctx.client.get_guild(guildid) + if not guild and sharded: + try: + guild = await ctx.client.fetch_guild(guildid) + except discord.HTTPException: + pass + if guild: + to_leave.append(guild) + for guild in to_leave: await guild.leave() @@ -102,9 +113,8 @@ async def cmd_guildblacklist(ctx, flags): ) # Refresh the cached blacklist after modification - ctx.client.objects['blacklisted_guilds'] = set( - row['guildid'] for row in ctx.client.data.global_guild_blacklist.select_where() - ) + ctx.client.guild_blacklist.cache_clear() + ctx.client.guild_blacklist() else: # Display the current blacklist # First fetch the full blacklist data @@ -183,7 +193,7 @@ async def cmd_userblacklist(ctx, flags): Description: View, add, or remove users from the blacklist. """ - blacklist = ctx.client.objects['blacklisted_users'] + blacklist = ctx.client.user_blacklist() if ctx.args: # userid parsing @@ -245,9 +255,8 @@ async def cmd_userblacklist(ctx, flags): ) # Refresh the cached blacklist after modification - ctx.client.objects['blacklisted_users'] = set( - row['userid'] for row in ctx.client.data.global_user_blacklist.select_where() - ) + ctx.client.user_blacklist.cache_clear() + ctx.client.user_blacklist() else: # Display the current blacklist # First fetch the full blacklist data diff --git a/bot/modules/sysadmin/status.py b/bot/modules/sysadmin/status.py index 83f02cc8..853f6410 100644 --- a/bot/modules/sysadmin/status.py +++ b/bot/modules/sysadmin/status.py @@ -13,19 +13,13 @@ async def update_status(): # TODO: Make globally configurable and saveable global _last_update - if time.time() - _last_update < 30: + if time.time() - _last_update < 60: return _last_update = time.time() - student_count = sum( - len(ch.members) - for guild in client.guilds - for ch in guild.voice_channels - ) - room_count = sum( - len([vc for vc in guild.voice_channels if vc.members]) - for guild in client.guilds + student_count, room_count = client.data.current_sessions.select_one_where( + select_columns=("COUNT(*) AS studying_count", "COUNT(DISTINCT(channelid)) AS channel_count"), ) status = "{} students in {} study rooms!".format(student_count, room_count) diff --git a/bot/modules/workout/tracker.py b/bot/modules/workout/tracker.py index 90eea397..be3438df 100644 --- a/bot/modules/workout/tracker.py +++ b/bot/modules/workout/tracker.py @@ -7,6 +7,7 @@ from core import Lion from settings import GuildSettings from meta import client from data import NULL, tables +from data.conditions import THIS_SHARD from .module import module from .data import workout_sessions @@ -170,7 +171,7 @@ async def workout_voice_tracker(client, member, before, after): if member.bot: return - if member.id in client.objects['blacklisted_users']: + if member.id in client.user_blacklist(): return if member.id in client.objects['ignored_members'][member.guild.id]: return @@ -226,7 +227,8 @@ async def load_workouts(client): client.objects['current_workouts'] = {} # (guildid, userid) -> Row # Process any incomplete workouts workouts = workout_sessions.fetch_rows_where( - duration=NULL + duration=NULL, + guildid=THIS_SHARD ) count = 0 for workout in workouts: diff --git a/config/example-bot.conf b/config/example-bot.conf index b2fc7a48..d2ec5dd1 100644 --- a/config/example-bot.conf +++ b/config/example-bot.conf @@ -1,6 +1,7 @@ [DEFAULT] log_file = bot.log log_channel = +error_channel = guild_log_channel = prefix = ! @@ -10,4 +11,6 @@ owners = 413668234269818890, 389399222400712714 database = dbname=lionbot data_appid = LionBot +shard_count = 1 + lion_sync_period = 60