From d498673020e7144ed59865ab1746d97307c1c529 Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 22 Dec 2021 10:34:34 +0200 Subject: [PATCH 1/8] sharding (data): Add `SHARDID` condition. --- bot/data/conditions.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/bot/data/conditions.py b/bot/data/conditions.py index 4687a929..fdd0739f 100644 --- a/bot/data/conditions.py +++ b/bot/data/conditions.py @@ -70,5 +70,17 @@ class Constant(Condition): conditions.append("{} {}".format(key, self.value)) +class SHARDID(Condition): + __slots__ = ('shardid', 'shard_count') + + def __init__(self, shardid, shard_count): + self.shardid = shardid + self.shard_count = shard_count + + def apply(self, key, values, conditions): + conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char)) + values.append(self.shardid) + + NULL = Constant('IS NULL') NOTNULL = Constant('IS NOT NULL') From 20697c48231eb5f94fd4bdf453d5f67fe8b71a30 Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 22 Dec 2021 11:28:43 +0200 Subject: [PATCH 2/8] sharding (core): Add base sharding support. Add `meta.args` for command line argument access. Add command line argument support for shard number. Add shard count to config file. Add `meta.sharding` exposing shard properties. Add shard number to logging methods. Add shard number to data appid. --- bot/data/__init__.py | 2 +- bot/data/conditions.py | 5 +++++ bot/main.py | 9 +++++++-- bot/meta/__init__.py | 4 +++- bot/meta/args.py | 19 +++++++++++++++++++ bot/meta/client.py | 15 +++++++++------ bot/meta/config.py | 8 ++++---- bot/meta/logger.py | 15 +++++++++++++-- bot/meta/sharding.py | 9 +++++++++ config/example-bot.conf | 3 +++ 10 files changed, 73 insertions(+), 16 deletions(-) create mode 100644 bot/meta/args.py create mode 100644 bot/meta/sharding.py diff --git a/bot/data/__init__.py b/bot/data/__init__.py index f048ce37..2deecc48 100644 --- a/bot/data/__init__.py +++ b/bot/data/__init__.py @@ -1,5 +1,5 @@ +from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa from .connection import conn # noqa from .formatters import UpdateValue, UpdateValueAdd # noqa from .interfaces import Table, RowTable, Row, tables # noqa from .queries import insert, insert_many, select_where, update_where, upsert, delete_where # noqa -from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa diff --git a/bot/data/conditions.py b/bot/data/conditions.py index fdd0739f..52999504 100644 --- a/bot/data/conditions.py +++ b/bot/data/conditions.py @@ -1,5 +1,7 @@ from .connection import _replace_char +from meta import sharding + class Condition: """ @@ -82,5 +84,8 @@ class SHARDID(Condition): values.append(self.shardid) +THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count) + + NULL = Constant('IS NULL') NOTNULL = Constant('IS NOT NULL') diff --git a/bot/main.py b/bot/main.py index ac818e36..066bf86e 100644 --- a/bot/main.py +++ b/bot/main.py @@ -1,4 +1,4 @@ -from meta import client, conf, log +from meta import client, conf, log, sharding from data import tables @@ -7,7 +7,12 @@ import core # noqa import modules # noqa # Load and attach app specific data -client.appdata = core.data.meta.fetch_or_create(conf.bot['data_appid']) +if sharding.sharded: + appname = f"{conf.bot['data_appid']}_{sharding.shard_count}_{sharding.shard_number}" +else: + appname = conf.bot['data_appid'] +client.appdata = core.data.meta.fetch_or_create(appname) + client.data = tables # Initialise all modules diff --git a/bot/meta/__init__.py b/bot/meta/__init__.py index dd852d4f..eab9c7b8 100644 --- a/bot/meta/__init__.py +++ b/bot/meta/__init__.py @@ -1,3 +1,5 @@ +from .logger import log, logger from .client import client from .config import conf -from .logger import log, logger +from .args import args +from . import sharding diff --git a/bot/meta/args.py b/bot/meta/args.py new file mode 100644 index 00000000..c2dd70d6 --- /dev/null +++ b/bot/meta/args.py @@ -0,0 +1,19 @@ +import argparse + +from constants import CONFIG_FILE + +# ------------------------------ +# Parsed commandline arguments +# ------------------------------ +parser = argparse.ArgumentParser() +parser.add_argument('--conf', + dest='config', + default=CONFIG_FILE, + help="Path to configuration file.") +parser.add_argument('--shard', + dest='shard', + default=None, + type=int, + help="Shard number to run, if applicable.") + +args = parser.parse_args() diff --git a/bot/meta/client.py b/bot/meta/client.py index 5310171d..50414aa8 100644 --- a/bot/meta/client.py +++ b/bot/meta/client.py @@ -1,16 +1,19 @@ from discord import Intents from cmdClient.cmdClient import cmdClient -from .config import Conf +from .config import conf +from .sharding import shard_number, shard_count -from constants import CONFIG_FILE - -# Initialise config -conf = Conf(CONFIG_FILE) # Initialise client owners = [int(owner) for owner in conf.bot.getlist('owners')] intents = Intents.all() intents.presences = False -client = cmdClient(prefix=conf.bot['prefix'], owners=owners, intents=intents) +client = cmdClient( + prefix=conf.bot['prefix'], + owners=owners, + intents=intents, + shard_id=shard_number, + shard_count=shard_count +) client.conf = conf diff --git a/bot/meta/config.py b/bot/meta/config.py index a94d2b1a..ca779924 100644 --- a/bot/meta/config.py +++ b/bot/meta/config.py @@ -1,9 +1,6 @@ import configparser as cfgp - -conf = None # type: Conf - -CONF_FILE = "bot/bot.conf" +from .args import args class Conf: @@ -57,3 +54,6 @@ class Conf: def write(self): with open(self.configfile, 'w') as conffile: self.config.write(conffile) + + +conf = Conf(args.config) diff --git a/bot/meta/logger.py b/bot/meta/logger.py index 858b1292..3e7bd026 100644 --- a/bot/meta/logger.py +++ b/bot/meta/logger.py @@ -9,11 +9,18 @@ from utils.lib import mail, split_text from .client import client from .config import conf +from . import sharding # Setup the logger logger = logging.getLogger() -log_fmt = logging.Formatter(fmt='[{asctime}][{levelname:^8}] {message}', datefmt='%d/%m | %H:%M:%S', style='{') +log_fmt = logging.Formatter( + fmt=('[{asctime}][{levelname:^8}]' + + '[SHARD {}]'.format(sharding.shard_number) + + '{message}'), + datefmt='%d/%m | %H:%M:%S', + style='{' +) # term_handler = logging.StreamHandler(sys.stdout) # term_handler.setFormatter(log_fmt) # logger.addHandler(term_handler) @@ -77,7 +84,11 @@ async def live_log(message, context, level): log_chid = conf.bot.getint('log_channel') # Generate the log messages - header = "[{}][{}]".format(logging.getLevelName(level), str(context)) + if sharding.sharded: + header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]" + else: + header = f"[{logging.getLevelName(level)}][{context}]" + if len(message) > 1900: blocks = split_text(message, blocksize=1900, code=False) else: diff --git a/bot/meta/sharding.py b/bot/meta/sharding.py new file mode 100644 index 00000000..ffe86a89 --- /dev/null +++ b/bot/meta/sharding.py @@ -0,0 +1,9 @@ +from .args import args +from .config import conf + + +shard_number = args.shard or 0 + +shard_count = conf.bot.getint('shard_count', 1) + +sharded = (shard_count > 0) diff --git a/config/example-bot.conf b/config/example-bot.conf index b2fc7a48..d2ec5dd1 100644 --- a/config/example-bot.conf +++ b/config/example-bot.conf @@ -1,6 +1,7 @@ [DEFAULT] log_file = bot.log log_channel = +error_channel = guild_log_channel = prefix = ! @@ -10,4 +11,6 @@ owners = 413668234269818890, 389399222400712714 database = dbname=lionbot data_appid = LionBot +shard_count = 1 + lion_sync_period = 60 From 1c05d7a88072f4a134217525a92590b86aa0ff96 Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 22 Dec 2021 13:07:20 +0200 Subject: [PATCH 3/8] sharding (blacklists): Blacklist shard support. Moved the `user_blacklist` and `guild_blacklist` to a client TTL cache. --- bot/LionModule.py | 4 +- bot/core/blacklists.py | 44 ++++++++----------- bot/modules/economy/cointop_cmd.py | 2 +- bot/modules/reminders/reminder.py | 2 +- bot/modules/study/stats_cmd.py | 2 +- bot/modules/study/top_cmd.py | 2 +- bot/modules/study/tracking/session_tracker.py | 2 +- bot/modules/study/tracking/time_tracker.py | 2 +- bot/modules/sysadmin/blacklist.py | 33 +++++++++----- bot/modules/workout/tracker.py | 2 +- 10 files changed, 48 insertions(+), 47 deletions(-) diff --git a/bot/LionModule.py b/bot/LionModule.py index 6e89cf5e..9c2a4671 100644 --- a/bot/LionModule.py +++ b/bot/LionModule.py @@ -82,7 +82,7 @@ class LionModule(Module): raise SafeCancellation(details="Module '{}' is not ready.".format(self.name)) # Check global user blacklist - if ctx.author.id in ctx.client.objects['blacklisted_users']: + if ctx.author.id in ctx.client.user_blacklist(): raise SafeCancellation(details='User is blacklisted.') if ctx.guild: @@ -91,7 +91,7 @@ class LionModule(Module): raise SafeCancellation(details='Command channel is no longer reachable.') # Check global guild blacklist - if ctx.guild.id in ctx.client.objects['blacklisted_guilds']: + if ctx.guild.id in ctx.client.guild_blacklist(): raise SafeCancellation(details='Guild is blacklisted.') # Check guild's own member blacklist diff --git a/bot/core/blacklists.py b/bot/core/blacklists.py index 942bd012..1ca5bd9c 100644 --- a/bot/core/blacklists.py +++ b/bot/core/blacklists.py @@ -1,9 +1,8 @@ """ Guild, user, and member blacklists. - -NOTE: The pre-loading methods are not shard-optimised. """ from collections import defaultdict +import cachetools.func from data import tables from meta import client @@ -11,32 +10,22 @@ from meta import client from .module import module -@module.init_task -def load_guild_blacklist(client): +@cachetools.func.ttl_cache(ttl=300) +def guild_blacklist(): """ - Load the blacklisted guilds. + Get the guild blacklist """ rows = tables.global_guild_blacklist.select_where() - client.objects['blacklisted_guilds'] = set(row['guildid'] for row in rows) - if rows: - client.log( - "Loaded {} blacklisted guilds.".format(len(rows)), - context="GUILD_BLACKLIST" - ) + return set(row['guildid'] for row in rows) -@module.init_task -def load_user_blacklist(client): +@cachetools.func.ttl_cache(ttl=300) +def user_blacklist(): """ - Load the blacklisted users. + Get the global user blacklist. """ rows = tables.global_user_blacklist.select_where() - client.objects['blacklisted_users'] = set(row['userid'] for row in rows) - if rows: - client.log( - "Loaded {} globally blacklisted users.".format(len(rows)), - context="USER_BLACKLIST" - ) + return set(row['userid'] for row in rows) @module.init_task @@ -62,18 +51,20 @@ def load_ignored_members(client): ) +@module.init_task +def attach_client_blacklists(client): + client.guild_blacklist = guild_blacklist + client.user_blacklist = user_blacklist + + @module.launch_task async def leave_blacklisted_guilds(client): """ Launch task to leave any blacklisted guilds we are in. - Assumes that the blacklisted guild list has been initialised. """ - # Cache to avoic repeated lookups - blacklisted = client.objects['blacklisted_guilds'] - to_leave = [ guild for guild in client.guilds - if guild.id in blacklisted + if guild.id in guild_blacklist() ] for guild in to_leave: @@ -92,7 +83,8 @@ async def check_guild_blacklist(client, guild): Guild join event handler to check whether the guild is blacklisted. If so, leaves the guild. """ - if guild.id in client.objects['blacklisted_guilds']: + # First refresh the blacklist cache + if guild.id in guild_blacklist(): await guild.leave() client.log( "Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id), diff --git a/bot/modules/economy/cointop_cmd.py b/bot/modules/economy/cointop_cmd.py index 9d1b9b2d..81bdbad9 100644 --- a/bot/modules/economy/cointop_cmd.py +++ b/bot/modules/economy/cointop_cmd.py @@ -43,7 +43,7 @@ async def cmd_topcoin(ctx): # Fetch the leaderboard exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members) - exclude.update(ctx.client.objects['blacklisted_users']) + exclude.update(ctx.client.user_blacklist()) exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id]) args = { diff --git a/bot/modules/reminders/reminder.py b/bot/modules/reminders/reminder.py index d3e4f764..73870341 100644 --- a/bot/modules/reminders/reminder.py +++ b/bot/modules/reminders/reminder.py @@ -134,7 +134,7 @@ class Reminder: """ Execute the reminder. """ - if self.data.userid in client.objects['blacklisted_users']: + if self.data.userid in client.user_blacklist(): self.delete(self.reminderid) return diff --git a/bot/modules/study/stats_cmd.py b/bot/modules/study/stats_cmd.py index 29e412e6..88bc8be5 100644 --- a/bot/modules/study/stats_cmd.py +++ b/bot/modules/study/stats_cmd.py @@ -59,7 +59,7 @@ async def cmd_stats(ctx): # Leaderboard ranks exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members) - exclude.update(ctx.client.objects['blacklisted_users']) + exclude.update(ctx.client.user_blacklist()) exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id]) if target.id in exclude: time_rank = None diff --git a/bot/modules/study/top_cmd.py b/bot/modules/study/top_cmd.py index cb4008f7..79564c1f 100644 --- a/bot/modules/study/top_cmd.py +++ b/bot/modules/study/top_cmd.py @@ -40,7 +40,7 @@ async def cmd_top(ctx): # Fetch the leaderboard exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members) - exclude.update(ctx.client.objects['blacklisted_users']) + exclude.update(ctx.client.user_blacklist()) exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id]) args = { diff --git a/bot/modules/study/tracking/session_tracker.py b/bot/modules/study/tracking/session_tracker.py index 96262f04..9cbe53be 100644 --- a/bot/modules/study/tracking/session_tracker.py +++ b/bot/modules/study/tracking/session_tracker.py @@ -298,7 +298,7 @@ async def session_voice_tracker(client, member, before, after): pending.cancel() if after.channel: - blacklist = client.objects['blacklisted_users'] + blacklist = client.user_blacklist() guild_blacklist = client.objects['ignored_members'][guild.id] untracked = untracked_channels.get(guild.id).data start_session = ( diff --git a/bot/modules/study/tracking/time_tracker.py b/bot/modules/study/tracking/time_tracker.py index 1cb35fa0..46f88ec7 100644 --- a/bot/modules/study/tracking/time_tracker.py +++ b/bot/modules/study/tracking/time_tracker.py @@ -47,7 +47,7 @@ def _scan(guild): members = itertools.chain(*channel_members) # TODO filter out blacklisted users - blacklist = client.objects['blacklisted_users'] + blacklist = client.user_blacklist() guild_blacklist = client.objects['ignored_members'][guild.id] for member in members: diff --git a/bot/modules/sysadmin/blacklist.py b/bot/modules/sysadmin/blacklist.py index 12a2ed9b..90202407 100644 --- a/bot/modules/sysadmin/blacklist.py +++ b/bot/modules/sysadmin/blacklist.py @@ -7,6 +7,8 @@ import discord from cmdClient.checks import is_owner from cmdClient.lib import ResponseTimedOut +from meta.sharding import sharded + from .module import module @@ -26,14 +28,14 @@ async def cmd_guildblacklist(ctx, flags): Description: View, add, or remove guilds from the blacklist. """ - blacklist = ctx.client.objects['blacklisted_guilds'] + blacklist = ctx.client.guild_blacklist() if ctx.args: # guildid parsing items = [item.strip() for item in ctx.args.split(',')] if any(not item.isdigit() for item in items): return await ctx.error_reply( - "Please provide guilds as comma seprated guild ids." + "Please provide guilds as comma separated guild ids." ) guildids = set(int(item) for item in items) @@ -80,9 +82,18 @@ async def cmd_guildblacklist(ctx, flags): insert_keys=('guildid', 'ownerid', 'reason') ) - # Check if we are in any of these guilds - to_leave = (ctx.client.get_guild(guildid) for guildid in to_add) - to_leave = [guild for guild in to_leave if guild is not None] + # Leave freshly blacklisted guilds, accounting for shards + to_leave = [] + for guildid in to_add: + guild = ctx.client.get_guild(guildid) + if not guild and sharded: + try: + guild = await ctx.client.fetch_guild(guildid) + except discord.HTTPException: + pass + if guild: + to_leave.append(guild) + for guild in to_leave: await guild.leave() @@ -102,9 +113,8 @@ async def cmd_guildblacklist(ctx, flags): ) # Refresh the cached blacklist after modification - ctx.client.objects['blacklisted_guilds'] = set( - row['guildid'] for row in ctx.client.data.global_guild_blacklist.select_where() - ) + ctx.client.guild_blacklist.cache_clear() + ctx.client.guild_blacklist() else: # Display the current blacklist # First fetch the full blacklist data @@ -183,7 +193,7 @@ async def cmd_userblacklist(ctx, flags): Description: View, add, or remove users from the blacklist. """ - blacklist = ctx.client.objects['blacklisted_users'] + blacklist = ctx.client.user_blacklist() if ctx.args: # userid parsing @@ -245,9 +255,8 @@ async def cmd_userblacklist(ctx, flags): ) # Refresh the cached blacklist after modification - ctx.client.objects['blacklisted_users'] = set( - row['userid'] for row in ctx.client.data.global_user_blacklist.select_where() - ) + ctx.client.user_blacklist.cache_clear() + ctx.client.user_blacklist() else: # Display the current blacklist # First fetch the full blacklist data diff --git a/bot/modules/workout/tracker.py b/bot/modules/workout/tracker.py index 90eea397..79dc9378 100644 --- a/bot/modules/workout/tracker.py +++ b/bot/modules/workout/tracker.py @@ -170,7 +170,7 @@ async def workout_voice_tracker(client, member, before, after): if member.bot: return - if member.id in client.objects['blacklisted_users']: + if member.id in client.user_blacklist(): return if member.id in client.objects['ignored_members'][member.guild.id]: return From 25e22c07d0ad72590695df53b2a3b310592d7549 Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 22 Dec 2021 13:20:27 +0200 Subject: [PATCH 4/8] sharding (tickets): Filter expiring tickets. Only expire tickets which are on this shard. `THIS_SHARD` application is a no-op when unsharded. --- bot/data/conditions.py | 5 +++-- bot/modules/moderation/tickets/Ticket.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/bot/data/conditions.py b/bot/data/conditions.py index 52999504..a314616e 100644 --- a/bot/data/conditions.py +++ b/bot/data/conditions.py @@ -80,8 +80,9 @@ class SHARDID(Condition): self.shard_count = shard_count def apply(self, key, values, conditions): - conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char)) - values.append(self.shardid) + if self.shard_count > 1: + conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char)) + values.append(self.shardid) THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count) diff --git a/bot/modules/moderation/tickets/Ticket.py b/bot/modules/moderation/tickets/Ticket.py index 4d7ec5ec..afea1eef 100644 --- a/bot/modules/moderation/tickets/Ticket.py +++ b/bot/modules/moderation/tickets/Ticket.py @@ -6,6 +6,7 @@ import datetime import discord from meta import client +from data.conditions import THIS_SHARD from settings import GuildSettings from utils.lib import FieldEnum, strfdelta, utc_now @@ -283,7 +284,8 @@ class Ticket: # Get all expiring tickets expiring_rows = data.tickets.select_where( - ticket_state=TicketState.EXPIRING + ticket_state=TicketState.EXPIRING, + guildid=THIS_SHARD ) # Create new expiry tasks From 276886a3a70b2da61e34aa34cd444fefac73131b Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 22 Dec 2021 17:26:49 +0200 Subject: [PATCH 5/8] sharding (general): Add launch data filters. Filter cached reaction role messages by shardid. Filter expiring rented room by shardid. Filter scanned study badges by shardid. Filter resumed study sessions by shardid. Filter resumed workouts by shardid. Fix a spacing issue in the log printer. --- bot/meta/logger.py | 2 +- bot/modules/guild_admin/reaction_roles/tracker.py | 3 ++- bot/modules/renting/rooms.py | 3 ++- bot/modules/study/badges/badge_tracker.py | 12 ++++++++---- bot/modules/study/tracking/session_tracker.py | 3 ++- bot/modules/workout/tracker.py | 4 +++- 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/bot/meta/logger.py b/bot/meta/logger.py index 3e7bd026..a95500e4 100644 --- a/bot/meta/logger.py +++ b/bot/meta/logger.py @@ -17,7 +17,7 @@ logger = logging.getLogger() log_fmt = logging.Formatter( fmt=('[{asctime}][{levelname:^8}]' + '[SHARD {}]'.format(sharding.shard_number) + - '{message}'), + ' {message}'), datefmt='%d/%m | %H:%M:%S', style='{' ) diff --git a/bot/modules/guild_admin/reaction_roles/tracker.py b/bot/modules/guild_admin/reaction_roles/tracker.py index f18e3c34..17a64960 100644 --- a/bot/modules/guild_admin/reaction_roles/tracker.py +++ b/bot/modules/guild_admin/reaction_roles/tracker.py @@ -12,6 +12,7 @@ from discord import PartialEmoji from meta import client from core import Lion from data import Row +from data.conditions import THIS_SHARD from utils.lib import utc_now from settings import GuildSettings @@ -584,5 +585,5 @@ def load_reaction_roles(client): """ Load the ReactionRoleMessages. """ - rows = reaction_role_messages.fetch_rows_where() + rows = reaction_role_messages.fetch_rows_where(guildid=THIS_SHARD) ReactionRoleMessage._messages = {row.messageid: ReactionRoleMessage(row.messageid) for row in rows} diff --git a/bot/modules/renting/rooms.py b/bot/modules/renting/rooms.py index 3e1d19c4..a8c29876 100644 --- a/bot/modules/renting/rooms.py +++ b/bot/modules/renting/rooms.py @@ -5,6 +5,7 @@ import datetime from cmdClient.lib import SafeCancellation from meta import client +from data.conditions import THIS_SHARD from settings import GuildSettings from .data import rented, rented_members @@ -276,7 +277,7 @@ class Room: @module.launch_task async def load_rented_rooms(client): - rows = rented.fetch_rows_where() + rows = rented.fetch_rows_where(guildid=THIS_SHARD) for row in rows: Room(row.channelid).schedule() client.log( diff --git a/bot/modules/study/badges/badge_tracker.py b/bot/modules/study/badges/badge_tracker.py index 2c0d33fb..721f3962 100644 --- a/bot/modules/study/badges/badge_tracker.py +++ b/bot/modules/study/badges/badge_tracker.py @@ -6,8 +6,8 @@ import contextlib import discord -from meta import client -from data.conditions import GEQ +from meta import client, sharding +from data.conditions import GEQ, THIS_SHARD from core.data import lions from utils.lib import strfdur from settings import GuildSettings @@ -54,12 +54,16 @@ async def update_study_badges(full=False): # Retrieve member rows with out of date study badges if not full and client.appdata.last_study_badge_scan is not None: + # TODO: _extra here is a hack to cover for inflexible conditionals update_rows = new_study_badges.select_where( + guildid=THIS_SHARD, _timestamp=GEQ(client.appdata.last_study_badge_scan or 0), - _extra="OR session_start IS NOT NULL" + _extra="OR session_start IS NOT NULL AND (guildid >> 22) %% {} = {}".format( + sharding.shard_count, sharding.shard_number + ) ) else: - update_rows = new_study_badges.select_where() + update_rows = new_study_badges.select_where(guildid=THIS_SHARD) if not update_rows: client.appdata.last_study_badge_scan = datetime.datetime.utcnow() diff --git a/bot/modules/study/tracking/session_tracker.py b/bot/modules/study/tracking/session_tracker.py index 9cbe53be..8158f96a 100644 --- a/bot/modules/study/tracking/session_tracker.py +++ b/bot/modules/study/tracking/session_tracker.py @@ -7,6 +7,7 @@ from collections import defaultdict from utils.lib import utc_now from data import tables +from data.conditions import THIS_SHARD from core import Lion from meta import client @@ -398,7 +399,7 @@ async def _init_session_tracker(client): ended = 0 # Grab all ongoing sessions from data - rows = current_sessions.fetch_rows_where() + rows = current_sessions.fetch_rows_where(guildid=THIS_SHARD) # Iterate through, resume or end as needed for row in rows: diff --git a/bot/modules/workout/tracker.py b/bot/modules/workout/tracker.py index 79dc9378..be3438df 100644 --- a/bot/modules/workout/tracker.py +++ b/bot/modules/workout/tracker.py @@ -7,6 +7,7 @@ from core import Lion from settings import GuildSettings from meta import client from data import NULL, tables +from data.conditions import THIS_SHARD from .module import module from .data import workout_sessions @@ -226,7 +227,8 @@ async def load_workouts(client): client.objects['current_workouts'] = {} # (guildid, userid) -> Row # Process any incomplete workouts workouts = workout_sessions.fetch_rows_where( - duration=NULL + duration=NULL, + guildid=THIS_SHARD ) count = 0 for workout in workouts: From 68ff40cb0b1b9a5e8a66d32da5814fec1c0b8b0d Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 22 Dec 2021 17:42:33 +0200 Subject: [PATCH 6/8] sharding (status): Use sessions for bot status. Uses the `current_sessions` table to generate the status summary. --- bot/modules/sysadmin/status.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/bot/modules/sysadmin/status.py b/bot/modules/sysadmin/status.py index 83f02cc8..853f6410 100644 --- a/bot/modules/sysadmin/status.py +++ b/bot/modules/sysadmin/status.py @@ -13,19 +13,13 @@ async def update_status(): # TODO: Make globally configurable and saveable global _last_update - if time.time() - _last_update < 30: + if time.time() - _last_update < 60: return _last_update = time.time() - student_count = sum( - len(ch.members) - for guild in client.guilds - for ch in guild.voice_channels - ) - room_count = sum( - len([vc for vc in guild.voice_channels if vc.members]) - for guild in client.guilds + student_count, room_count = client.data.current_sessions.select_one_where( + select_columns=("COUNT(*) AS studying_count", "COUNT(DISTINCT(channelid)) AS channel_count"), ) status = "{} students in {} study rooms!".format(student_count, room_count) From 0dd5213f13a8ca5793ae9ff5bbbf498e5e94bee4 Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 22 Dec 2021 19:07:28 +0200 Subject: [PATCH 7/8] sharding (accountability): Adapt for sharding. Filter initially loaded accountability guilds. Filter timeslots loaded in `open_next`. Reload members and overwrites on slot start. --- bot/modules/accountability/TimeSlot.py | 29 +++++++++++++++++++++++- bot/modules/accountability/tracker.py | 31 ++++++++++++++++++++++---- 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/bot/modules/accountability/TimeSlot.py b/bot/modules/accountability/TimeSlot.py index 9464e4e5..81cbe38d 100644 --- a/bot/modules/accountability/TimeSlot.py +++ b/bot/modules/accountability/TimeSlot.py @@ -90,7 +90,6 @@ class TimeSlot: @property def open_embed(self): - # TODO Consider adding hint to footer timestamp = int(self.start_time.timestamp()) embed = discord.Embed( @@ -247,6 +246,34 @@ class TimeSlot: return self + async def _reload_members(self, memberids=None): + """ + Reload the timeslot members from the provided list, or data. + Also updates the channel overwrites if required. + To be used before the session has started. + """ + if self.data: + if memberids is None: + member_rows = accountability_members.fetch_rows_where(slotid=self.data.slotid) + memberids = [row.userid for row in member_rows] + + self.members = members = { + memberid: SlotMember(self.data.slotid, memberid, self.guild) + for memberid in memberids + } + + if self.channel: + # Check and potentially update overwrites + current_overwrites = self.channel.overwrites + overwrites = { + mem.member: self._member_overwrite + for mem in members.values() + if mem.member + } + overwrites[self.guild.default_role] = self._everyone_overwrite + if current_overwrites != overwrites: + await self.channel.edit(overwrites=overwrites) + def _refresh(self): """ Refresh the stored data row and reload. diff --git a/bot/modules/accountability/tracker.py b/bot/modules/accountability/tracker.py index 24e1dc94..faa82867 100644 --- a/bot/modules/accountability/tracker.py +++ b/bot/modules/accountability/tracker.py @@ -10,7 +10,7 @@ from discord.utils import sleep_until from meta import client from utils.interactive import discord_shield from data import NULL, NOTNULL, tables -from data.conditions import LEQ +from data.conditions import LEQ, THIS_SHARD from settings import GuildSettings from .TimeSlot import TimeSlot @@ -67,7 +67,8 @@ async def open_next(start_time): """ # Pre-fetch the new slot data, also populating the table caches room_data = accountability_rooms.fetch_rows_where( - start_at=start_time + start_at=start_time, + guildid=THIS_SHARD ) guild_rows = {row.guildid: row for row in room_data} member_data = accountability_members.fetch_rows_where( @@ -193,11 +194,30 @@ async def turnover(): # TODO: (FUTURE) with high volume, we might want to start the sessions before moving the members. # We could break up the session starting? - # Move members of the next session over to the session channel + # ---------- Start next session ---------- current_slots = [ aguild.current_slot for aguild in AccountabilityGuild.cache.values() if aguild.current_slot is not None ] + slotmap = {slot.data.slotid: slot for slot in current_slots if slot.data} + + # Reload the slot members in case they cancelled from another shard + member_data = accountability_members.fetch_rows_where( + slotid=list(slotmap.keys()) + ) if slotmap else [] + slot_memberids = {slotid: [] for slotid in slotmap} + for row in member_data: + slot_memberids[row.slotid].append(row.userid) + reload_tasks = ( + slot._reload_members(memberids=slot_memberids[slotid]) + for slotid, slot in slotmap.items() + ) + await asyncio.gather( + *reload_tasks, + return_exceptions=True + ) + + # Move members of the next session over to the session channel movement_tasks = ( mem.member.edit( voice_channel=slot.channel, @@ -335,6 +355,7 @@ async def _accountability_system_resume(): open_room_data = accountability_rooms.fetch_rows_where( closed_at=NULL, start_at=LEQ(now), + guildid=THIS_SHARD, _extra="ORDER BY start_at ASC" ) @@ -450,8 +471,10 @@ async def launch_accountability_system(client): """ # Load the AccountabilityGuild cache guilds = tables.guild_config.fetch_rows_where( - accountability_category=NOTNULL + accountability_category=NOTNULL, + guildid=THIS_SHARD ) + # Further filter out any guilds that we aren't in [AccountabilityGuild(guild.guildid) for guild in guilds if client.get_guild(guild.guildid)] await _accountability_system_resume() asyncio.create_task(_accountability_loop()) From e979e5cf45937cd4b7dedf183319e73bc60a6310 Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 22 Dec 2021 20:24:24 +0200 Subject: [PATCH 8/8] sharding (reminders): Adapt for sharding. Restrict reminder execution to shard `0`. Add a poll on shard `0` to pick up new reminders. Check whether the reminder still exists on execution. --- bot/modules/reminders/commands.py | 8 ++-- bot/modules/reminders/reminder.py | 79 ++++++++++++++++++++++++------- 2 files changed, 67 insertions(+), 20 deletions(-) diff --git a/bot/modules/reminders/commands.py b/bot/modules/reminders/commands.py index 0bb98d60..c8637a04 100644 --- a/bot/modules/reminders/commands.py +++ b/bot/modules/reminders/commands.py @@ -3,6 +3,7 @@ import asyncio import datetime import discord +from meta import sharding from utils.lib import parse_dur, parse_ranges, multiselect_regex from .module import module @@ -55,7 +56,7 @@ async def cmd_remindme(ctx, flags): if not rows: return await ctx.reply("You have no reminders to remove!") - live = Reminder.fetch(*(row.reminderid for row in rows)) + live = [Reminder(row.reminderid) for row in rows] if not ctx.args: lines = [] @@ -209,7 +210,8 @@ async def cmd_remindme(ctx, flags): ) # Schedule reminder - reminder.schedule() + if sharding.shard_number == 0: + reminder.schedule() # Ack embed = discord.Embed( @@ -231,7 +233,7 @@ async def cmd_remindme(ctx, flags): if not rows: return await ctx.reply("You have no reminders!") - live = Reminder.fetch(*(row.reminderid for row in rows)) + live = [Reminder(row.reminderid) for row in rows] lines = [] num_field = len(str(len(live) - 1)) diff --git a/bot/modules/reminders/reminder.py b/bot/modules/reminders/reminder.py index 73870341..67956a1d 100644 --- a/bot/modules/reminders/reminder.py +++ b/bot/modules/reminders/reminder.py @@ -1,8 +1,9 @@ import asyncio import datetime +import logging import discord -from meta import client +from meta import client, sharding from utils.lib import strfdur from .data import reminders @@ -46,7 +47,10 @@ class Reminder: cls._live_reminders[reminderid].cancel() # Remove from data - reminders.delete_where(reminderid=reminderids) + if reminderids: + return reminders.delete_where(reminderid=reminderids) + else: + return [] @property def data(self): @@ -134,10 +138,16 @@ class Reminder: """ Execute the reminder. """ + if not self.data: + # Reminder deleted elsewhere + return + if self.data.userid in client.user_blacklist(): self.delete(self.reminderid) return + userid = self.data.userid + # Build the message embed embed = discord.Embed( title="You asked me to remind you!", @@ -155,8 +165,26 @@ class Reminder: ) ) + # Update the reminder data, and reschedule if required + if self.data.interval: + next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval) + rows = reminders.update_where( + {'remind_at': next_time}, + reminderid=self.reminderid + ) + self.schedule() + else: + rows = self.delete(self.reminderid) + if not rows: + # Reminder deleted elsewhere + return + # Send the message, if possible - user = self.user + if not (user := client.get_user(userid)): + try: + user = await client.fetch_user(userid) + except discord.HTTPException: + pass if user: try: await user.send(embed=embed) @@ -164,21 +192,38 @@ class Reminder: # Nothing we can really do here. Maybe tell the user about their reminder next time? pass - # Update the reminder data, and reschedule if required - if self.data.interval: - next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval) - reminders.update_where({'remind_at': next_time}, reminderid=self.reminderid) - self.schedule() - else: - self.delete(self.reminderid) + +async def reminder_poll(client): + """ + One client/shard must continually poll for new or deleted reminders. + """ + # TODO: Clean this up with database signals or IPC + while True: + await asyncio.sleep(60) + + client.log( + "Running new reminder poll.", + context="REMINDERS", + level=logging.DEBUG + ) + + rids = {row.reminderid for row in reminders.fetch_rows_where()} + + to_delete = (rid for rid in Reminder._live_reminders if rid not in rids) + Reminder.delete(*to_delete) + + [Reminder(rid).schedule() for rid in rids if rid not in Reminder._live_reminders] @module.launch_task async def schedule_reminders(client): - rows = reminders.fetch_rows_where() - for row in rows: - Reminder(row.reminderid).schedule() - client.log( - "Scheduled {} reminders.".format(len(rows)), - context="LAUNCH_REMINDERS" - ) + if sharding.shard_number == 0: + rows = reminders.fetch_rows_where() + for row in rows: + Reminder(row.reminderid).schedule() + client.log( + "Scheduled {} reminders.".format(len(rows)), + context="LAUNCH_REMINDERS" + ) + if sharding.sharded: + asyncio.create_task(reminder_poll(client))