Merge pull request #13 from StudyLions/sharding

Sharding support.
This commit is contained in:
Interitio
2021-12-22 20:45:49 +02:00
committed by GitHub
28 changed files with 279 additions and 106 deletions

View File

@@ -82,7 +82,7 @@ class LionModule(Module):
raise SafeCancellation(details="Module '{}' is not ready.".format(self.name))
# Check global user blacklist
if ctx.author.id in ctx.client.objects['blacklisted_users']:
if ctx.author.id in ctx.client.user_blacklist():
raise SafeCancellation(details='User is blacklisted.')
if ctx.guild:
@@ -91,7 +91,7 @@ class LionModule(Module):
raise SafeCancellation(details='Command channel is no longer reachable.')
# Check global guild blacklist
if ctx.guild.id in ctx.client.objects['blacklisted_guilds']:
if ctx.guild.id in ctx.client.guild_blacklist():
raise SafeCancellation(details='Guild is blacklisted.')
# Check guild's own member blacklist

View File

@@ -1,9 +1,8 @@
"""
Guild, user, and member blacklists.
NOTE: The pre-loading methods are not shard-optimised.
"""
from collections import defaultdict
import cachetools.func
from data import tables
from meta import client
@@ -11,32 +10,22 @@ from meta import client
from .module import module
@module.init_task
def load_guild_blacklist(client):
@cachetools.func.ttl_cache(ttl=300)
def guild_blacklist():
"""
Load the blacklisted guilds.
Get the guild blacklist
"""
rows = tables.global_guild_blacklist.select_where()
client.objects['blacklisted_guilds'] = set(row['guildid'] for row in rows)
if rows:
client.log(
"Loaded {} blacklisted guilds.".format(len(rows)),
context="GUILD_BLACKLIST"
)
return set(row['guildid'] for row in rows)
@module.init_task
def load_user_blacklist(client):
@cachetools.func.ttl_cache(ttl=300)
def user_blacklist():
"""
Load the blacklisted users.
Get the global user blacklist.
"""
rows = tables.global_user_blacklist.select_where()
client.objects['blacklisted_users'] = set(row['userid'] for row in rows)
if rows:
client.log(
"Loaded {} globally blacklisted users.".format(len(rows)),
context="USER_BLACKLIST"
)
return set(row['userid'] for row in rows)
@module.init_task
@@ -62,18 +51,20 @@ def load_ignored_members(client):
)
@module.init_task
def attach_client_blacklists(client):
client.guild_blacklist = guild_blacklist
client.user_blacklist = user_blacklist
@module.launch_task
async def leave_blacklisted_guilds(client):
"""
Launch task to leave any blacklisted guilds we are in.
Assumes that the blacklisted guild list has been initialised.
"""
# Cache to avoic repeated lookups
blacklisted = client.objects['blacklisted_guilds']
to_leave = [
guild for guild in client.guilds
if guild.id in blacklisted
if guild.id in guild_blacklist()
]
for guild in to_leave:
@@ -92,7 +83,8 @@ async def check_guild_blacklist(client, guild):
Guild join event handler to check whether the guild is blacklisted.
If so, leaves the guild.
"""
if guild.id in client.objects['blacklisted_guilds']:
# Check against the guild blacklist (TTL-cached for 300s, so not force-refreshed here)
if guild.id in guild_blacklist():
await guild.leave()
client.log(
"Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id),

View File

@@ -1,5 +1,5 @@
from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa
from .connection import conn # noqa
from .formatters import UpdateValue, UpdateValueAdd # noqa
from .interfaces import Table, RowTable, Row, tables # noqa
from .queries import insert, insert_many, select_where, update_where, upsert, delete_where # noqa
from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa

View File

@@ -1,5 +1,7 @@
from .connection import _replace_char
from meta import sharding
class Condition:
"""
@@ -70,5 +72,21 @@ class Constant(Condition):
conditions.append("{} {}".format(key, self.value))
class SHARDID(Condition):
    """
    Condition restricting an id column to the rows owned by one shard.

    A key matches when `(key >> 22) % shard_count` equals `shardid`.
    NOTE(review): this presumably mirrors Discord's guild-to-shard mapping,
    confirm against the gateway documentation.

    When `shard_count` is 1 or less, the condition adds no restriction at all.
    """
    __slots__ = ('shardid', 'shard_count')

    def __init__(self, shardid, shard_count):
        self.shardid = shardid
        self.shard_count = shard_count

    def apply(self, key, values, conditions):
        """
        Append the shard filter for the column `key` to the query being built.

        Adds one clause to `conditions` and the matching shard id to `values`.
        Both lists are left untouched when there is only a single shard.
        """
        if self.shard_count <= 1:
            # Single shard: every row belongs to us, so no filtering is needed
            return

        # `%%` is left doubled by `str.format`
        # (presumably so the query layer later reads it as a literal `%` -- TODO confirm)
        shard_clause = "({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char)
        conditions.append(shard_clause)
        values.append(self.shardid)
THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
NULL = Constant('IS NULL')
NOTNULL = Constant('IS NOT NULL')

View File

@@ -1,4 +1,4 @@
from meta import client, conf, log
from meta import client, conf, log, sharding
from data import tables
@@ -7,7 +7,12 @@ import core # noqa
import modules # noqa
# Load and attach app specific data
client.appdata = core.data.meta.fetch_or_create(conf.bot['data_appid'])
if sharding.sharded:
appname = f"{conf.bot['data_appid']}_{sharding.shard_count}_{sharding.shard_number}"
else:
appname = conf.bot['data_appid']
client.appdata = core.data.meta.fetch_or_create(appname)
client.data = tables
# Initialise all modules

View File

@@ -1,3 +1,5 @@
from .logger import log, logger
from .client import client
from .config import conf
from .logger import log, logger
from .args import args
from . import sharding

19
bot/meta/args.py Normal file
View File

@@ -0,0 +1,19 @@
import argparse
from constants import CONFIG_FILE
# ------------------------------
# Parsed commandline arguments
# ------------------------------
# The command line is parsed as soon as this module is imported,
# and the result is exposed as the module level `args` namespace.
parser = argparse.ArgumentParser()

# `--conf`: path of the configuration file, stored as `args.config`.
# Falls back to `CONFIG_FILE` when the option is not given.
parser.add_argument(
    '--conf',
    dest='config',
    default=CONFIG_FILE,
    help="Path to configuration file.",
)

# `--shard`: integer shard number, stored as `args.shard`.
# Stays `None` when the option is not given.
parser.add_argument(
    '--shard',
    dest='shard',
    type=int,
    default=None,
    help="Shard number to run, if applicable.",
)

args = parser.parse_args()

View File

@@ -1,16 +1,19 @@
from discord import Intents
from cmdClient.cmdClient import cmdClient
from .config import Conf
from .config import conf
from .sharding import shard_number, shard_count
from constants import CONFIG_FILE
# Initialise config
conf = Conf(CONFIG_FILE)
# Initialise client
owners = [int(owner) for owner in conf.bot.getlist('owners')]
intents = Intents.all()
intents.presences = False
client = cmdClient(prefix=conf.bot['prefix'], owners=owners, intents=intents)
client = cmdClient(
prefix=conf.bot['prefix'],
owners=owners,
intents=intents,
shard_id=shard_number,
shard_count=shard_count
)
client.conf = conf

View File

@@ -1,9 +1,6 @@
import configparser as cfgp
conf = None # type: Conf
CONF_FILE = "bot/bot.conf"
from .args import args
class Conf:
@@ -57,3 +54,6 @@ class Conf:
def write(self):
with open(self.configfile, 'w') as conffile:
self.config.write(conffile)
conf = Conf(args.config)

View File

@@ -9,11 +9,18 @@ from utils.lib import mail, split_text
from .client import client
from .config import conf
from . import sharding
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(fmt='[{asctime}][{levelname:^8}] {message}', datefmt='%d/%m | %H:%M:%S', style='{')
log_fmt = logging.Formatter(
fmt=('[{asctime}][{levelname:^8}]' +
'[SHARD {}]'.format(sharding.shard_number) +
' {message}'),
datefmt='%d/%m | %H:%M:%S',
style='{'
)
# term_handler = logging.StreamHandler(sys.stdout)
# term_handler.setFormatter(log_fmt)
# logger.addHandler(term_handler)
@@ -77,7 +84,11 @@ async def live_log(message, context, level):
log_chid = conf.bot.getint('log_channel')
# Generate the log messages
header = "[{}][{}]".format(logging.getLevelName(level), str(context))
if sharding.sharded:
header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
else:
header = f"[{logging.getLevelName(level)}][{context}]"
if len(message) > 1900:
blocks = split_text(message, blocksize=1900, code=False)
else:

9
bot/meta/sharding.py Normal file
View File

@@ -0,0 +1,9 @@
from .args import args
from .config import conf
# Number of the shard run by this process, taken from the `--shard`
# command line argument. Falls back to 0 when the argument is not given.
shard_number = args.shard or 0

# Total number of shards, read from the `shard_count` bot config option.
# Defaults to a single shard when the option is missing.
shard_count = conf.bot.getint('shard_count', 1)

# Whether the bot is actually split over several shards.
# A `shard_count` of 1 (the default) is an unsharded deployment, so this must
# test for *more than one* shard; this also agrees with the `shard_count > 1`
# test used when filtering data rows by shard.
sharded = (shard_count > 1)

View File

@@ -90,7 +90,6 @@ class TimeSlot:
@property
def open_embed(self):
# TODO Consider adding hint to footer
timestamp = int(self.start_time.timestamp())
embed = discord.Embed(
@@ -247,6 +246,34 @@ class TimeSlot:
return self
async def _reload_members(self, memberids=None):
    """
    Rebuild the member mapping of this timeslot, then sync the channel overwrites.

    Arguments:
        memberids: Optional iterable of member user ids belonging to this slot.
            When `None`, the ids are read from the `accountability_members`
            rows matching this slot's `slotid`.

    Does nothing when the timeslot has no data row.
    The channel is only edited when its current overwrites differ
    from the freshly computed ones.
    To be used before the session has started.
    """
    if not self.data:
        # No slot data, so there are no members to reload
        return

    if memberids is None:
        # No explicit list provided, fall back to the stored slot members
        stored_rows = accountability_members.fetch_rows_where(slotid=self.data.slotid)
        memberids = [stored.userid for stored in stored_rows]

    reloaded = {}
    for memberid in memberids:
        reloaded[memberid] = SlotMember(self.data.slotid, memberid, self.guild)
    self.members = reloaded

    if not self.channel:
        # No channel to update
        return

    # Compare the existing overwrites with the ones the members require
    existing = self.channel.overwrites

    required = {}
    for slot_member in reloaded.values():
        # Skip slot members without an associated `member`
        if slot_member.member:
            required[slot_member.member] = self._member_overwrite
    required[self.guild.default_role] = self._everyone_overwrite

    if existing != required:
        await self.channel.edit(overwrites=required)
def _refresh(self):
"""
Refresh the stored data row and reload.

View File

@@ -10,7 +10,7 @@ from discord.utils import sleep_until
from meta import client
from utils.interactive import discord_shield
from data import NULL, NOTNULL, tables
from data.conditions import LEQ
from data.conditions import LEQ, THIS_SHARD
from settings import GuildSettings
from .TimeSlot import TimeSlot
@@ -67,7 +67,8 @@ async def open_next(start_time):
"""
# Pre-fetch the new slot data, also populating the table caches
room_data = accountability_rooms.fetch_rows_where(
start_at=start_time
start_at=start_time,
guildid=THIS_SHARD
)
guild_rows = {row.guildid: row for row in room_data}
member_data = accountability_members.fetch_rows_where(
@@ -193,11 +194,30 @@ async def turnover():
# TODO: (FUTURE) with high volume, we might want to start the sessions before moving the members.
# We could break up the session starting?
# Move members of the next session over to the session channel
# ---------- Start next session ----------
current_slots = [
aguild.current_slot for aguild in AccountabilityGuild.cache.values()
if aguild.current_slot is not None
]
slotmap = {slot.data.slotid: slot for slot in current_slots if slot.data}
# Reload the slot members in case they cancelled from another shard
member_data = accountability_members.fetch_rows_where(
slotid=list(slotmap.keys())
) if slotmap else []
slot_memberids = {slotid: [] for slotid in slotmap}
for row in member_data:
slot_memberids[row.slotid].append(row.userid)
reload_tasks = (
slot._reload_members(memberids=slot_memberids[slotid])
for slotid, slot in slotmap.items()
)
await asyncio.gather(
*reload_tasks,
return_exceptions=True
)
# Move members of the next session over to the session channel
movement_tasks = (
mem.member.edit(
voice_channel=slot.channel,
@@ -335,6 +355,7 @@ async def _accountability_system_resume():
open_room_data = accountability_rooms.fetch_rows_where(
closed_at=NULL,
start_at=LEQ(now),
guildid=THIS_SHARD,
_extra="ORDER BY start_at ASC"
)
@@ -450,8 +471,10 @@ async def launch_accountability_system(client):
"""
# Load the AccountabilityGuild cache
guilds = tables.guild_config.fetch_rows_where(
accountability_category=NOTNULL
accountability_category=NOTNULL,
guildid=THIS_SHARD
)
# Further filter out any guilds that we aren't in
[AccountabilityGuild(guild.guildid) for guild in guilds if client.get_guild(guild.guildid)]
await _accountability_system_resume()
asyncio.create_task(_accountability_loop())

View File

@@ -43,7 +43,7 @@ async def cmd_topcoin(ctx):
# Fetch the leaderboard
exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members)
exclude.update(ctx.client.objects['blacklisted_users'])
exclude.update(ctx.client.user_blacklist())
exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id])
args = {

View File

@@ -12,6 +12,7 @@ from discord import PartialEmoji
from meta import client
from core import Lion
from data import Row
from data.conditions import THIS_SHARD
from utils.lib import utc_now
from settings import GuildSettings
@@ -584,5 +585,5 @@ def load_reaction_roles(client):
"""
Load the ReactionRoleMessages.
"""
rows = reaction_role_messages.fetch_rows_where()
rows = reaction_role_messages.fetch_rows_where(guildid=THIS_SHARD)
ReactionRoleMessage._messages = {row.messageid: ReactionRoleMessage(row.messageid) for row in rows}

View File

@@ -6,6 +6,7 @@ import datetime
import discord
from meta import client
from data.conditions import THIS_SHARD
from settings import GuildSettings
from utils.lib import FieldEnum, strfdelta, utc_now
@@ -283,7 +284,8 @@ class Ticket:
# Get all expiring tickets
expiring_rows = data.tickets.select_where(
ticket_state=TicketState.EXPIRING
ticket_state=TicketState.EXPIRING,
guildid=THIS_SHARD
)
# Create new expiry tasks

View File

@@ -3,6 +3,7 @@ import asyncio
import datetime
import discord
from meta import sharding
from utils.lib import parse_dur, parse_ranges, multiselect_regex
from .module import module
@@ -55,7 +56,7 @@ async def cmd_remindme(ctx, flags):
if not rows:
return await ctx.reply("You have no reminders to remove!")
live = Reminder.fetch(*(row.reminderid for row in rows))
live = [Reminder(row.reminderid) for row in rows]
if not ctx.args:
lines = []
@@ -209,6 +210,7 @@ async def cmd_remindme(ctx, flags):
)
# Schedule reminder
if sharding.shard_number == 0:
reminder.schedule()
# Ack
@@ -231,7 +233,7 @@ async def cmd_remindme(ctx, flags):
if not rows:
return await ctx.reply("You have no reminders!")
live = Reminder.fetch(*(row.reminderid for row in rows))
live = [Reminder(row.reminderid) for row in rows]
lines = []
num_field = len(str(len(live) - 1))

View File

@@ -1,8 +1,9 @@
import asyncio
import datetime
import logging
import discord
from meta import client
from meta import client, sharding
from utils.lib import strfdur
from .data import reminders
@@ -46,7 +47,10 @@ class Reminder:
cls._live_reminders[reminderid].cancel()
# Remove from data
reminders.delete_where(reminderid=reminderids)
if reminderids:
return reminders.delete_where(reminderid=reminderids)
else:
return []
@property
def data(self):
@@ -134,10 +138,16 @@ class Reminder:
"""
Execute the reminder.
"""
if self.data.userid in client.objects['blacklisted_users']:
if not self.data:
# Reminder deleted elsewhere
return
if self.data.userid in client.user_blacklist():
self.delete(self.reminderid)
return
userid = self.data.userid
# Build the message embed
embed = discord.Embed(
title="You asked me to remind you!",
@@ -155,8 +165,26 @@ class Reminder:
)
)
# Update the reminder data, and reschedule if required
if self.data.interval:
next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval)
rows = reminders.update_where(
{'remind_at': next_time},
reminderid=self.reminderid
)
self.schedule()
else:
rows = self.delete(self.reminderid)
if not rows:
# Reminder deleted elsewhere
return
# Send the message, if possible
user = self.user
if not (user := client.get_user(userid)):
try:
user = await client.fetch_user(userid)
except discord.HTTPException:
pass
if user:
try:
await user.send(embed=embed)
@@ -164,17 +192,32 @@ class Reminder:
# Nothing we can really do here. Maybe tell the user about their reminder next time?
pass
# Update the reminder data, and reschedule if required
if self.data.interval:
next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval)
reminders.update_where({'remind_at': next_time}, reminderid=self.reminderid)
self.schedule()
else:
self.delete(self.reminderid)
async def reminder_poll(client):
    """
    One client/shard must continually poll for new or deleted reminders.

    Every 60 seconds, the stored reminder ids are compared with the live
    reminders held by this process:
      - Live reminders which no longer exist in data are deleted.
      - Stored reminders which are not yet live are scheduled.

    A failure during one poll cycle is logged and the poll carries on with
    the next cycle, so that a transient error (e.g. a failed data query)
    does not permanently stop reminder synchronisation.

    Arguments:
        client: The client used for logging.
    """
    # TODO: Clean this up with database signals or IPC
    import traceback

    while True:
        await asyncio.sleep(60)

        client.log(
            "Running new reminder poll.",
            context="REMINDERS",
            level=logging.DEBUG
        )

        try:
            rids = {row.reminderid for row in reminders.fetch_rows_where()}

            # Snapshot the stale ids before deleting,
            # so we never iterate over the live reminders while they are modified
            to_delete = [rid for rid in Reminder._live_reminders if rid not in rids]
            Reminder.delete(*to_delete)

            # Schedule any reminders which were created elsewhere
            for rid in rids:
                if rid not in Reminder._live_reminders:
                    Reminder(rid).schedule()
        except Exception:
            # Keep the poller alive; an unhandled exception would end this task for good
            client.log(
                "Reminder poll failed, retrying on the next poll.\n{}".format(traceback.format_exc()),
                context="REMINDERS",
                level=logging.ERROR
            )
@module.launch_task
async def schedule_reminders(client):
if sharding.shard_number == 0:
rows = reminders.fetch_rows_where()
for row in rows:
Reminder(row.reminderid).schedule()
@@ -182,3 +225,5 @@ async def schedule_reminders(client):
"Scheduled {} reminders.".format(len(rows)),
context="LAUNCH_REMINDERS"
)
if sharding.sharded:
asyncio.create_task(reminder_poll(client))

View File

@@ -5,6 +5,7 @@ import datetime
from cmdClient.lib import SafeCancellation
from meta import client
from data.conditions import THIS_SHARD
from settings import GuildSettings
from .data import rented, rented_members
@@ -276,7 +277,7 @@ class Room:
@module.launch_task
async def load_rented_rooms(client):
rows = rented.fetch_rows_where()
rows = rented.fetch_rows_where(guildid=THIS_SHARD)
for row in rows:
Room(row.channelid).schedule()
client.log(

View File

@@ -6,8 +6,8 @@ import contextlib
import discord
from meta import client
from data.conditions import GEQ
from meta import client, sharding
from data.conditions import GEQ, THIS_SHARD
from core.data import lions
from utils.lib import strfdur
from settings import GuildSettings
@@ -54,12 +54,16 @@ async def update_study_badges(full=False):
# Retrieve member rows with out of date study badges
if not full and client.appdata.last_study_badge_scan is not None:
# TODO: _extra here is a hack to cover for inflexible conditionals
update_rows = new_study_badges.select_where(
guildid=THIS_SHARD,
_timestamp=GEQ(client.appdata.last_study_badge_scan or 0),
_extra="OR session_start IS NOT NULL"
_extra="OR session_start IS NOT NULL AND (guildid >> 22) %% {} = {}".format(
sharding.shard_count, sharding.shard_number
)
)
else:
update_rows = new_study_badges.select_where()
update_rows = new_study_badges.select_where(guildid=THIS_SHARD)
if not update_rows:
client.appdata.last_study_badge_scan = datetime.datetime.utcnow()

View File

@@ -59,7 +59,7 @@ async def cmd_stats(ctx):
# Leaderboard ranks
exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members)
exclude.update(ctx.client.objects['blacklisted_users'])
exclude.update(ctx.client.user_blacklist())
exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id])
if target.id in exclude:
time_rank = None

View File

@@ -40,7 +40,7 @@ async def cmd_top(ctx):
# Fetch the leaderboard
exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members)
exclude.update(ctx.client.objects['blacklisted_users'])
exclude.update(ctx.client.user_blacklist())
exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id])
args = {

View File

@@ -7,6 +7,7 @@ from collections import defaultdict
from utils.lib import utc_now
from data import tables
from data.conditions import THIS_SHARD
from core import Lion
from meta import client
@@ -298,7 +299,7 @@ async def session_voice_tracker(client, member, before, after):
pending.cancel()
if after.channel:
blacklist = client.objects['blacklisted_users']
blacklist = client.user_blacklist()
guild_blacklist = client.objects['ignored_members'][guild.id]
untracked = untracked_channels.get(guild.id).data
start_session = (
@@ -398,7 +399,7 @@ async def _init_session_tracker(client):
ended = 0
# Grab all ongoing sessions from data
rows = current_sessions.fetch_rows_where()
rows = current_sessions.fetch_rows_where(guildid=THIS_SHARD)
# Iterate through, resume or end as needed
for row in rows:

View File

@@ -47,7 +47,7 @@ def _scan(guild):
members = itertools.chain(*channel_members)
# TODO filter out blacklisted users
blacklist = client.objects['blacklisted_users']
blacklist = client.user_blacklist()
guild_blacklist = client.objects['ignored_members'][guild.id]
for member in members:

View File

@@ -7,6 +7,8 @@ import discord
from cmdClient.checks import is_owner
from cmdClient.lib import ResponseTimedOut
from meta.sharding import sharded
from .module import module
@@ -26,14 +28,14 @@ async def cmd_guildblacklist(ctx, flags):
Description:
View, add, or remove guilds from the blacklist.
"""
blacklist = ctx.client.objects['blacklisted_guilds']
blacklist = ctx.client.guild_blacklist()
if ctx.args:
# guildid parsing
items = [item.strip() for item in ctx.args.split(',')]
if any(not item.isdigit() for item in items):
return await ctx.error_reply(
"Please provide guilds as comma seprated guild ids."
"Please provide guilds as comma separated guild ids."
)
guildids = set(int(item) for item in items)
@@ -80,9 +82,18 @@ async def cmd_guildblacklist(ctx, flags):
insert_keys=('guildid', 'ownerid', 'reason')
)
# Check if we are in any of these guilds
to_leave = (ctx.client.get_guild(guildid) for guildid in to_add)
to_leave = [guild for guild in to_leave if guild is not None]
# Leave freshly blacklisted guilds, accounting for shards
to_leave = []
for guildid in to_add:
guild = ctx.client.get_guild(guildid)
if not guild and sharded:
try:
guild = await ctx.client.fetch_guild(guildid)
except discord.HTTPException:
pass
if guild:
to_leave.append(guild)
for guild in to_leave:
await guild.leave()
@@ -102,9 +113,8 @@ async def cmd_guildblacklist(ctx, flags):
)
# Refresh the cached blacklist after modification
ctx.client.objects['blacklisted_guilds'] = set(
row['guildid'] for row in ctx.client.data.global_guild_blacklist.select_where()
)
ctx.client.guild_blacklist.cache_clear()
ctx.client.guild_blacklist()
else:
# Display the current blacklist
# First fetch the full blacklist data
@@ -183,7 +193,7 @@ async def cmd_userblacklist(ctx, flags):
Description:
View, add, or remove users from the blacklist.
"""
blacklist = ctx.client.objects['blacklisted_users']
blacklist = ctx.client.user_blacklist()
if ctx.args:
# userid parsing
@@ -245,9 +255,8 @@ async def cmd_userblacklist(ctx, flags):
)
# Refresh the cached blacklist after modification
ctx.client.objects['blacklisted_users'] = set(
row['userid'] for row in ctx.client.data.global_user_blacklist.select_where()
)
ctx.client.user_blacklist.cache_clear()
ctx.client.user_blacklist()
else:
# Display the current blacklist
# First fetch the full blacklist data

View File

@@ -13,19 +13,13 @@ async def update_status():
# TODO: Make globally configurable and saveable
global _last_update
if time.time() - _last_update < 30:
if time.time() - _last_update < 60:
return
_last_update = time.time()
student_count = sum(
len(ch.members)
for guild in client.guilds
for ch in guild.voice_channels
)
room_count = sum(
len([vc for vc in guild.voice_channels if vc.members])
for guild in client.guilds
student_count, room_count = client.data.current_sessions.select_one_where(
select_columns=("COUNT(*) AS studying_count", "COUNT(DISTINCT(channelid)) AS channel_count"),
)
status = "{} students in {} study rooms!".format(student_count, room_count)

View File

@@ -7,6 +7,7 @@ from core import Lion
from settings import GuildSettings
from meta import client
from data import NULL, tables
from data.conditions import THIS_SHARD
from .module import module
from .data import workout_sessions
@@ -170,7 +171,7 @@ async def workout_voice_tracker(client, member, before, after):
if member.bot:
return
if member.id in client.objects['blacklisted_users']:
if member.id in client.user_blacklist():
return
if member.id in client.objects['ignored_members'][member.guild.id]:
return
@@ -226,7 +227,8 @@ async def load_workouts(client):
client.objects['current_workouts'] = {} # (guildid, userid) -> Row
# Process any incomplete workouts
workouts = workout_sessions.fetch_rows_where(
duration=NULL
duration=NULL,
guildid=THIS_SHARD
)
count = 0
for workout in workouts:

View File

@@ -1,6 +1,7 @@
[DEFAULT]
log_file = bot.log
log_channel =
error_channel =
guild_log_channel =
prefix = !
@@ -10,4 +11,6 @@ owners = 413668234269818890, 389399222400712714
database = dbname=lionbot
data_appid = LionBot
shard_count = 1
lion_sync_period = 60