Merge branch 'staging' into feature-timer

This commit is contained in:
2022-01-01 07:07:25 +02:00
63 changed files with 2501 additions and 432 deletions

140
.gitignore vendored Normal file
View File

@@ -0,0 +1,140 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
config/**

View File

@@ -13,7 +13,7 @@ class LionCommand(Command):
""" """
Subclass to allow easy attachment of custom hooks and structure to commands. Subclass to allow easy attachment of custom hooks and structure to commands.
""" """
... allow_before_ready = False
class LionModule(Module): class LionModule(Module):
@@ -72,25 +72,38 @@ class LionModule(Module):
""" """
Lion pre-command hook. Lion pre-command hook.
""" """
if not self.ready and not ctx.cmd.allow_before_ready:
try:
await ctx.embed_reply(
"I am currently restarting! Please try again in a couple of minutes."
)
except discord.HTTPException:
pass
raise SafeCancellation(details="Module '{}' is not ready.".format(self.name))
# Check global user blacklist # Check global user blacklist
if ctx.author.id in ctx.client.objects['blacklisted_users']: if ctx.author.id in ctx.client.user_blacklist():
raise SafeCancellation raise SafeCancellation(details='User is blacklisted.')
if ctx.guild: if ctx.guild:
# Check that the channel and guild still exists
if not ctx.client.get_guild(ctx.guild.id) or not ctx.guild.get_channel(ctx.ch.id):
raise SafeCancellation(details='Command channel is no longer reachable.')
# Check global guild blacklist # Check global guild blacklist
if ctx.guild.id in ctx.client.objects['blacklisted_guilds']: if ctx.guild.id in ctx.client.guild_blacklist():
raise SafeCancellation raise SafeCancellation(details='Guild is blacklisted.')
# Check guild's own member blacklist # Check guild's own member blacklist
if ctx.author.id in ctx.client.objects['ignored_members'][ctx.guild.id]: if ctx.author.id in ctx.client.objects['ignored_members'][ctx.guild.id]:
raise SafeCancellation raise SafeCancellation(details='User is ignored in this guild.')
# Check channel permissions are sane # Check channel permissions are sane
if not ctx.ch.permissions_for(ctx.guild.me).send_messages: if not ctx.ch.permissions_for(ctx.guild.me).send_messages:
raise SafeCancellation raise SafeCancellation(details='I cannot send messages in this channel.')
if not ctx.ch.permissions_for(ctx.guild.me).embed_links: if not ctx.ch.permissions_for(ctx.guild.me).embed_links:
await ctx.reply("I need permission to send embeds in this channel before I can run any commands!") await ctx.reply("I need permission to send embeds in this channel before I can run any commands!")
raise SafeCancellation raise SafeCancellation(details='I cannot send embeds in this channel.')
# Start typing # Start typing
await ctx.ch.trigger_typing() await ctx.ch.trigger_typing()

View File

@@ -1,2 +1,2 @@
CONFIG_FILE = "config/bot.conf" CONFIG_FILE = "config/bot.conf"
DATA_VERSION = 5 DATA_VERSION = 6

View File

@@ -1,9 +1,8 @@
""" """
Guild, user, and member blacklists. Guild, user, and member blacklists.
NOTE: The pre-loading methods are not shard-optimised.
""" """
from collections import defaultdict from collections import defaultdict
import cachetools.func
from data import tables from data import tables
from meta import client from meta import client
@@ -11,32 +10,22 @@ from meta import client
from .module import module from .module import module
@module.init_task @cachetools.func.ttl_cache(ttl=300)
def load_guild_blacklist(client): def guild_blacklist():
""" """
Load the blacklisted guilds. Get the guild blacklist
""" """
rows = tables.global_guild_blacklist.select_where() rows = tables.global_guild_blacklist.select_where()
client.objects['blacklisted_guilds'] = set(row['guildid'] for row in rows) return set(row['guildid'] for row in rows)
if rows:
client.log(
"Loaded {} blacklisted guilds.".format(len(rows)),
context="GUILD_BLACKLIST"
)
@module.init_task @cachetools.func.ttl_cache(ttl=300)
def load_user_blacklist(client): def user_blacklist():
""" """
Load the blacklisted users. Get the global user blacklist.
""" """
rows = tables.global_user_blacklist.select_where() rows = tables.global_user_blacklist.select_where()
client.objects['blacklisted_users'] = set(row['userid'] for row in rows) return set(row['userid'] for row in rows)
if rows:
client.log(
"Loaded {} globally blacklisted users.".format(len(rows)),
context="USER_BLACKLIST"
)
@module.init_task @module.init_task
@@ -62,18 +51,20 @@ def load_ignored_members(client):
) )
@module.init_task
def attach_client_blacklists(client):
client.guild_blacklist = guild_blacklist
client.user_blacklist = user_blacklist
@module.launch_task @module.launch_task
async def leave_blacklisted_guilds(client): async def leave_blacklisted_guilds(client):
""" """
Launch task to leave any blacklisted guilds we are in. Launch task to leave any blacklisted guilds we are in.
Assumes that the blacklisted guild list has been initialised.
""" """
# Cache to avoic repeated lookups
blacklisted = client.objects['blacklisted_guilds']
to_leave = [ to_leave = [
guild for guild in client.guilds guild for guild in client.guilds
if guild.id in blacklisted if guild.id in guild_blacklist()
] ]
for guild in to_leave: for guild in to_leave:
@@ -92,7 +83,8 @@ async def check_guild_blacklist(client, guild):
Guild join event handler to check whether the guild is blacklisted. Guild join event handler to check whether the guild is blacklisted.
If so, leaves the guild. If so, leaves the guild.
""" """
if guild.id in client.objects['blacklisted_guilds']: # First refresh the blacklist cache
if guild.id in guild_blacklist():
await guild.leave() await guild.leave()
client.log( client.log(
"Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id), "Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id),

View File

@@ -20,46 +20,21 @@ user_config = RowTable(
) )
@user_config.save_query
def add_pending(pending):
"""
pending:
List of tuples of the form `(userid, pending_coins, pending_time)`.
"""
with lions.conn:
cursor = lions.conn.cursor()
data = execute_values(
cursor,
"""
UPDATE members
SET
coins = coins + t.coin_diff,
tracked_time = tracked_time + t.time_diff
FROM
(VALUES %s)
AS
t (guildid, userid, coin_diff, time_diff)
WHERE
members.guildid = t.guildid
AND
members.userid = t.userid
RETURNING *
""",
pending,
fetch=True
)
return lions._make_rows(*data)
guild_config = RowTable( guild_config = RowTable(
'guild_config', 'guild_config',
('guildid', 'admin_role', 'mod_role', 'event_log_channel', 'alert_channel', ('guildid', 'admin_role', 'mod_role', 'event_log_channel', 'mod_log_channel', 'alert_channel',
'studyban_role', 'max_study_bans',
'min_workout_length', 'workout_reward', 'min_workout_length', 'workout_reward',
'max_tasks', 'task_reward', 'task_reward_limit', 'max_tasks', 'task_reward', 'task_reward_limit',
'study_hourly_reward', 'study_hourly_live_bonus', 'study_hourly_reward', 'study_hourly_live_bonus', 'daily_study_cap',
'study_ban_role', 'max_study_bans'), 'renting_price', 'renting_category', 'renting_cap', 'renting_role', 'renting_sync_perms',
'accountability_category', 'accountability_lobby', 'accountability_bonus',
'accountability_reward', 'accountability_price',
'video_studyban', 'video_grace_period',
'greeting_channel', 'greeting_message', 'returning_message',
'starting_funds', 'persist_roles'),
'guildid', 'guildid',
cache=TTLCache(1000, ttl=60*5) cache=TTLCache(2500, ttl=60*5)
) )
unranked_roles = Table('unranked_roles') unranked_roles = Table('unranked_roles')
@@ -72,6 +47,7 @@ lions = RowTable(
('guildid', 'userid', ('guildid', 'userid',
'tracked_time', 'coins', 'tracked_time', 'coins',
'workout_count', 'last_workout_start', 'workout_count', 'last_workout_start',
'revision_mute_count',
'last_study_badgeid', 'last_study_badgeid',
'video_warned', 'video_warned',
'_timestamp' '_timestamp'
@@ -81,9 +57,66 @@ lions = RowTable(
attach_as='lions' attach_as='lions'
) )
@lions.save_query
def add_pending(pending):
"""
pending:
List of tuples of the form `(guildid, userid, pending_coins)`.
"""
with lions.conn:
cursor = lions.conn.cursor()
data = execute_values(
cursor,
"""
UPDATE members
SET
coins = coins + t.coin_diff
FROM
(VALUES %s)
AS
t (guildid, userid, coin_diff)
WHERE
members.guildid = t.guildid
AND
members.userid = t.userid
RETURNING *
""",
pending,
fetch=True
)
return lions._make_rows(*data)
lion_ranks = Table('member_ranks', attach_as='lion_ranks') lion_ranks = Table('member_ranks', attach_as='lion_ranks')
@lions.save_query
def get_member_rank(guildid, userid, untracked):
"""
Get the time and coin ranking for the given member, ignoring the provided untracked members.
"""
with lions.conn as conn:
with conn.cursor() as curs:
curs.execute(
"""
SELECT
time_rank, coin_rank
FROM (
SELECT
userid,
row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank,
row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank
FROM members_totals
WHERE
guildid=%s AND userid NOT IN %s
) AS guild_ranks WHERE userid=%s
""",
(guildid, tuple(untracked), userid)
)
return curs.fetchone() or (None, None)
global_guild_blacklist = Table('global_guild_blacklist') global_guild_blacklist = Table('global_guild_blacklist')
global_user_blacklist = Table('global_user_blacklist') global_user_blacklist = Table('global_user_blacklist')
ignored_members = Table('ignored_members') ignored_members = Table('ignored_members')

View File

@@ -1,4 +1,5 @@
import pytz import pytz
from datetime import datetime, timedelta
from meta import client from meta import client
from data import tables as tb from data import tables as tb
@@ -11,7 +12,7 @@ class Lion:
Mostly acts as a transparent interface to the corresponding Row, Mostly acts as a transparent interface to the corresponding Row,
but also adds some transaction caching logic to `coins` and `tracked_time`. but also adds some transaction caching logic to `coins` and `tracked_time`.
""" """
__slots__ = ('guildid', 'userid', '_pending_coins', '_pending_time', '_member') __slots__ = ('guildid', 'userid', '_pending_coins', '_member')
# Members with pending transactions # Members with pending transactions
_pending = {} # userid -> User _pending = {} # userid -> User
@@ -24,7 +25,6 @@ class Lion:
self.userid = userid self.userid = userid
self._pending_coins = 0 self._pending_coins = 0
self._pending_time = 0
self._member = None self._member = None
@@ -41,6 +41,7 @@ class Lion:
if key in cls._lions: if key in cls._lions:
return cls._lions[key] return cls._lions[key]
else: else:
# TODO: Debug log
lion = tb.lions.fetch(key) lion = tb.lions.fetch(key)
if not lion: if not lion:
tb.lions.create_row( tb.lions.create_row(
@@ -77,23 +78,134 @@ class Lion:
@property @property
def settings(self): def settings(self):
""" """
The UserSettings object for this user. The UserSettings interface for this member.
""" """
return UserSettings(self.userid) return UserSettings(self.userid)
@property
def guild_settings(self):
"""
The GuildSettings interface for this member.
"""
return GuildSettings(self.guildid)
@property @property
def time(self): def time(self):
""" """
Amount of time the user has spent studying, accounting for pending values. Amount of time the user has spent studying, accounting for a current session.
""" """
return int(self.data.tracked_time + self._pending_time) # Base time from cached member data
time = self.data.tracked_time
# Add current session time if it exists
if session := self.session:
time += session.duration
return int(time)
@property @property
def coins(self): def coins(self):
""" """
Number of coins the user has, accounting for the pending value. Number of coins the user has, accounting for the pending value and current session.
""" """
return int(self.data.coins + self._pending_coins) # Base coin amount from cached member data
coins = self.data.coins
# Add pending coin amount
coins += self._pending_coins
# Add current session coins if applicable
if session := self.session:
coins += session.coins_earned
return int(coins)
@property
def session(self):
"""
The current study session the user is in, if any.
"""
if 'sessions' not in client.objects:
raise ValueError("Cannot retrieve session before Study module is initialised!")
return client.objects['sessions'][self.guildid].get(self.userid, None)
@property
def timezone(self):
"""
The user's configured timezone.
Shortcut to `Lion.settings.timezone.value`.
"""
return self.settings.timezone.value
@property
def day_start(self):
"""
A timezone aware datetime representing the start of the user's day (in their configured timezone).
NOTE: This might not be accurate over DST boundaries.
"""
now = datetime.now(tz=self.timezone)
return now.replace(hour=0, minute=0, second=0, microsecond=0)
@property
def day_timestamp(self):
"""
EPOCH timestamp representing the current day for the user.
NOTE: This is the timestamp of the start of the current UTC day with the same date as the user's day.
This is *not* the start of the current user's day, either in UTC or their own timezone.
This may also not be the start of the current day in UTC (consider 23:00 for a user in UTC-2).
"""
now = datetime.now(tz=self.timezone)
day_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
return int(day_start.replace(tzinfo=pytz.utc).timestamp())
@property
def week_timestamp(self):
"""
EPOCH timestamp representing the current week for the user.
"""
now = datetime.now(tz=self.timezone)
day_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
week_start = day_start - timedelta(days=day_start.weekday())
return int(week_start.replace(tzinfo=pytz.utc).timestamp())
@property
def month_timestamp(self):
"""
EPOCH timestamp representing the current month for the user.
"""
now = datetime.now(tz=self.timezone)
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return int(month_start.replace(tzinfo=pytz.utc).timestamp())
@property
def remaining_in_day(self):
return ((self.day_start + timedelta(days=1)) - datetime.now(self.timezone)).total_seconds()
@property
def studied_today(self):
"""
The amount of time, in seconds, that the member has studied today.
Extracted from the session history.
"""
return tb.session_history.queries.study_time_since(self.guildid, self.userid, self.day_start)
@property
def remaining_study_today(self):
"""
Maximum remaining time (in seconds) this member can study today.
May not account for DST boundaries and leap seconds.
"""
studied_today = self.studied_today
study_cap = self.guild_settings.daily_study_cap.value
remaining_in_day = self.remaining_in_day
if remaining_in_day >= (study_cap - studied_today):
remaining = study_cap - studied_today
else:
remaining = remaining_in_day + study_cap
return remaining
def localize(self, naive_utc_dt): def localize(self, naive_utc_dt):
""" """
@@ -111,15 +223,6 @@ class Lion:
if flush: if flush:
self.flush() self.flush()
def addTime(self, amount, flush=True):
"""
Add time to a user (in seconds), optionally storing the transaction in pending.
"""
self._pending_time += amount
self._pending[self.key] = self
if flush:
self.flush()
def flush(self): def flush(self):
""" """
Flush any pending transactions to the database. Flush any pending transactions to the database.
@@ -137,7 +240,7 @@ class Lion:
if lions: if lions:
# Build userid to pending coin map # Build userid to pending coin map
pending = [ pending = [
(lion.guildid, lion.userid, int(lion._pending_coins), int(lion._pending_time)) (lion.guildid, lion.userid, int(lion._pending_coins))
for lion in lions for lion in lions
] ]
@@ -147,5 +250,4 @@ class Lion:
# Cleanup pending users # Cleanup pending users
for lion in lions: for lion in lions:
lion._pending_coins -= int(lion._pending_coins) lion._pending_coins -= int(lion._pending_coins)
lion._pending_time -= int(lion._pending_time)
cls._pending.pop(lion.key, None) cls._pending.pop(lion.key, None)

View File

@@ -61,9 +61,10 @@ async def preload_studying_members(client):
""" """
userids = list(set(member.id for guild in client.guilds for ch in guild.voice_channels for member in ch.members)) userids = list(set(member.id for guild in client.guilds for ch in guild.voice_channels for member in ch.members))
if userids: if userids:
rows = client.data.lions.fetch_rows_where(userid=userids) users = client.data.user_config.fetch_rows_where(userid=userids)
members = client.data.lions.fetch_rows_where(userid=userids)
client.log( client.log(
"Preloaded member data for {} members.".format(len(rows)), "Preloaded data for {} user with {} members.".format(len(users), len(members)),
context="CORE_LOADING" context="CORE_LOADING"
) )

View File

@@ -1,5 +1,5 @@
from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa
from .connection import conn # noqa from .connection import conn # noqa
from .formatters import UpdateValue, UpdateValueAdd # noqa from .formatters import UpdateValue, UpdateValueAdd # noqa
from .interfaces import Table, RowTable, Row, tables # noqa from .interfaces import Table, RowTable, Row, tables # noqa
from .queries import insert, insert_many, select_where, update_where, upsert, delete_where # noqa from .queries import insert, insert_many, select_where, update_where, upsert, delete_where # noqa
from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa

View File

@@ -1,5 +1,7 @@
from .connection import _replace_char from .connection import _replace_char
from meta import sharding
class Condition: class Condition:
""" """
@@ -70,5 +72,21 @@ class Constant(Condition):
conditions.append("{} {}".format(key, self.value)) conditions.append("{} {}".format(key, self.value))
class SHARDID(Condition):
__slots__ = ('shardid', 'shard_count')
def __init__(self, shardid, shard_count):
self.shardid = shardid
self.shard_count = shard_count
def apply(self, key, values, conditions):
if self.shard_count > 1:
conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
values.append(self.shardid)
THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
NULL = Constant('IS NULL') NULL = Constant('IS NULL')
NOTNULL = Constant('IS NOT NULL') NOTNULL = Constant('IS NOT NULL')

View File

@@ -45,10 +45,10 @@ class Table:
Intended to be subclassed to provide more derivative access for specific tables. Intended to be subclassed to provide more derivative access for specific tables.
""" """
conn = conn conn = conn
queries = DotDict()
def __init__(self, name, attach_as=None): def __init__(self, name, attach_as=None):
self.name = name self.name = name
self.queries = DotDict()
tables[attach_as or name] = self tables[attach_as or name] = self
@_connection_guard @_connection_guard

View File

@@ -1,4 +1,4 @@
from meta import client, conf, log from meta import client, conf, log, sharding
from data import tables from data import tables
@@ -7,7 +7,12 @@ import core # noqa
import modules # noqa import modules # noqa
# Load and attach app specific data # Load and attach app specific data
client.appdata = core.data.meta.fetch_or_create(conf.bot['data_appid']) if sharding.sharded:
appname = f"{conf.bot['data_appid']}_{sharding.shard_count}_{sharding.shard_number}"
else:
appname = conf.bot['data_appid']
client.appdata = core.data.meta.fetch_or_create(appname)
client.data = tables client.data = tables
# Initialise all modules # Initialise all modules

View File

@@ -1,3 +1,5 @@
from .logger import log, logger
from .client import client from .client import client
from .config import conf from .config import conf
from .logger import log, logger from .args import args
from . import sharding

19
bot/meta/args.py Normal file
View File

@@ -0,0 +1,19 @@
import argparse
from constants import CONFIG_FILE
# ------------------------------
# Parsed commandline arguments
# ------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--conf',
dest='config',
default=CONFIG_FILE,
help="Path to configuration file.")
parser.add_argument('--shard',
dest='shard',
default=None,
type=int,
help="Shard number to run, if applicable.")
args = parser.parse_args()

View File

@@ -1,16 +1,19 @@
from discord import Intents from discord import Intents
from cmdClient.cmdClient import cmdClient from cmdClient.cmdClient import cmdClient
from .config import Conf from .config import conf
from .sharding import shard_number, shard_count
from constants import CONFIG_FILE
# Initialise config
conf = Conf(CONFIG_FILE)
# Initialise client # Initialise client
owners = [int(owner) for owner in conf.bot.getlist('owners')] owners = [int(owner) for owner in conf.bot.getlist('owners')]
intents = Intents.all() intents = Intents.all()
intents.presences = False intents.presences = False
client = cmdClient(prefix=conf.bot['prefix'], owners=owners, intents=intents) client = cmdClient(
prefix=conf.bot['prefix'],
owners=owners,
intents=intents,
shard_id=shard_number,
shard_count=shard_count
)
client.conf = conf client.conf = conf

View File

@@ -1,9 +1,6 @@
import configparser as cfgp import configparser as cfgp
from .args import args
conf = None # type: Conf
CONF_FILE = "bot/bot.conf"
class Conf: class Conf:
@@ -57,3 +54,6 @@ class Conf:
def write(self): def write(self):
with open(self.configfile, 'w') as conffile: with open(self.configfile, 'w') as conffile:
self.config.write(conffile) self.config.write(conffile)
conf = Conf(args.config)

View File

@@ -9,11 +9,18 @@ from utils.lib import mail, split_text
from .client import client from .client import client
from .config import conf from .config import conf
from . import sharding
# Setup the logger # Setup the logger
logger = logging.getLogger() logger = logging.getLogger()
log_fmt = logging.Formatter(fmt='[{asctime}][{levelname:^8}] {message}', datefmt='%d/%m | %H:%M:%S', style='{') log_fmt = logging.Formatter(
fmt=('[{asctime}][{levelname:^8}]' +
'[SHARD {}]'.format(sharding.shard_number) +
' {message}'),
datefmt='%d/%m | %H:%M:%S',
style='{'
)
# term_handler = logging.StreamHandler(sys.stdout) # term_handler = logging.StreamHandler(sys.stdout)
# term_handler.setFormatter(log_fmt) # term_handler.setFormatter(log_fmt)
# logger.addHandler(term_handler) # logger.addHandler(term_handler)
@@ -77,7 +84,11 @@ async def live_log(message, context, level):
log_chid = conf.bot.getint('log_channel') log_chid = conf.bot.getint('log_channel')
# Generate the log messages # Generate the log messages
header = "[{}][{}]".format(logging.getLevelName(level), str(context)) if sharding.sharded:
header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
else:
header = f"[{logging.getLevelName(level)}][{context}]"
if len(message) > 1900: if len(message) > 1900:
blocks = split_text(message, blocksize=1900, code=False) blocks = split_text(message, blocksize=1900, code=False)
else: else:

9
bot/meta/sharding.py Normal file
View File

@@ -0,0 +1,9 @@
from .args import args
from .config import conf
shard_number = args.shard or 0
shard_count = conf.bot.getint('shard_count', 1)
sharded = (shard_count > 0)

View File

@@ -3,6 +3,7 @@ from .guild_admin import *
from .meta import * from .meta import *
from .economy import * from .economy import *
from .study import * from .study import *
from .stats import *
from .user_config import * from .user_config import *
from .workout import * from .workout import *
from .todo import * from .todo import *

View File

@@ -69,7 +69,8 @@ class TimeSlot:
_everyone_overwrite = discord.PermissionOverwrite( _everyone_overwrite = discord.PermissionOverwrite(
view_channel=False, view_channel=False,
connect=False connect=False,
speak=False
) )
happy_lion = "https://media.discordapp.net/stickers/898266283559227422.png" happy_lion = "https://media.discordapp.net/stickers/898266283559227422.png"
@@ -89,7 +90,6 @@ class TimeSlot:
@property @property
def open_embed(self): def open_embed(self):
# TODO Consider adding hint to footer
timestamp = int(self.start_time.timestamp()) timestamp = int(self.start_time.timestamp())
embed = discord.Embed( embed = discord.Embed(
@@ -218,6 +218,9 @@ class TimeSlot:
""" """
Load data and update applicable caches. Load data and update applicable caches.
""" """
if not self.guild:
return self
# Load setting data # Load setting data
self.category = GuildSettings(self.guild.id).accountability_category.value self.category = GuildSettings(self.guild.id).accountability_category.value
self.lobby = GuildSettings(self.guild.id).accountability_lobby.value self.lobby = GuildSettings(self.guild.id).accountability_lobby.value
@@ -228,7 +231,7 @@ class TimeSlot:
self.channel = self.guild.get_channel(self.data.channelid) self.channel = self.guild.get_channel(self.data.channelid)
# Load message # Load message
if self.data.messageid: if self.data.messageid and self.lobby:
self.message = discord.PartialMessage( self.message = discord.PartialMessage(
channel=self.lobby, channel=self.lobby,
id=self.data.messageid id=self.data.messageid
@@ -243,6 +246,34 @@ class TimeSlot:
return self return self
async def _reload_members(self, memberids=None):
"""
Reload the timeslot members from the provided list, or data.
Also updates the channel overwrites if required.
To be used before the session has started.
"""
if self.data:
if memberids is None:
member_rows = accountability_members.fetch_rows_where(slotid=self.data.slotid)
memberids = [row.userid for row in member_rows]
self.members = members = {
memberid: SlotMember(self.data.slotid, memberid, self.guild)
for memberid in memberids
}
if self.channel:
# Check and potentially update overwrites
current_overwrites = self.channel.overwrites
overwrites = {
mem.member: self._member_overwrite
for mem in members.values()
if mem.member
}
overwrites[self.guild.default_role] = self._everyone_overwrite
if current_overwrites != overwrites:
await self.channel.edit(overwrites=overwrites)
def _refresh(self): def _refresh(self):
""" """
Refresh the stored data row and reload. Refresh the stored data row and reload.
@@ -389,6 +420,7 @@ class TimeSlot:
pass pass
# Reward members appropriately # Reward members appropriately
if self.guild:
guild_settings = GuildSettings(self.guild.id) guild_settings = GuildSettings(self.guild.id)
reward = guild_settings.accountability_reward.value reward = guild_settings.accountability_reward.value
if all(mem.has_attended for mem in self.members.values()): if all(mem.has_attended for mem in self.members.values()):

View File

@@ -39,6 +39,7 @@ def time_format(time):
time.timestamp() + 3600, time.timestamp() + 3600,
) )
user_locks = {} # Map userid -> ctx user_locks = {} # Map userid -> ctx
@@ -229,7 +230,10 @@ async def cmd_rooms(ctx):
start_time + datetime.timedelta(hours=n) start_time + datetime.timedelta(hours=n)
for n in range(1, 25) for n in range(1, 25)
) )
times = [time for time in times if time not in already_joined_times] times = [
time for time in times
if time not in already_joined_times and (time - utc_now()).total_seconds() > 660
]
lines = [ lines = [
"`[{num:>2}]` | `{count:>{count_pad}}` attending | {time}".format( "`[{num:>2}]` | `{count:>{count_pad}}` attending | {time}".format(
num=i, num=i,
@@ -255,7 +259,7 @@ async def cmd_rooms(ctx):
await ctx.cancellable( await ctx.cancellable(
out_msg, out_msg,
cancel_message="Booking menu cancelled, no sessions were booked.", cancel_message="Booking menu cancelled, no sessions were booked.",
timeout=70 timeout=60
) )
def check(msg): def check(msg):
@@ -265,7 +269,7 @@ async def cmd_rooms(ctx):
with ensure_exclusive(ctx): with ensure_exclusive(ctx):
try: try:
message = await ctx.client.wait_for('message', check=check, timeout=60) message = await ctx.client.wait_for('message', check=check, timeout=30)
except asyncio.TimeoutError: except asyncio.TimeoutError:
try: try:
await out_msg.edit( await out_msg.edit(
@@ -325,6 +329,7 @@ async def cmd_rooms(ctx):
) )
# Handle case where the slot has already opened # Handle case where the slot has already opened
# TODO: Fix this, doesn't always work
aguild = AGuild.cache.get(ctx.guild.id, None) aguild = AGuild.cache.get(ctx.guild.id, None)
if aguild: if aguild:
if aguild.upcoming_slot and aguild.upcoming_slot.start_time in to_book: if aguild.upcoming_slot and aguild.upcoming_slot.start_time in to_book:

View File

@@ -10,7 +10,7 @@ from discord.utils import sleep_until
from meta import client from meta import client
from utils.interactive import discord_shield from utils.interactive import discord_shield
from data import NULL, NOTNULL, tables from data import NULL, NOTNULL, tables
from data.conditions import LEQ from data.conditions import LEQ, THIS_SHARD
from settings import GuildSettings from settings import GuildSettings
from .TimeSlot import TimeSlot from .TimeSlot import TimeSlot
@@ -67,7 +67,8 @@ async def open_next(start_time):
""" """
# Pre-fetch the new slot data, also populating the table caches # Pre-fetch the new slot data, also populating the table caches
room_data = accountability_rooms.fetch_rows_where( room_data = accountability_rooms.fetch_rows_where(
start_at=start_time start_at=start_time,
guildid=THIS_SHARD
) )
guild_rows = {row.guildid: row for row in room_data} guild_rows = {row.guildid: row for row in room_data}
member_data = accountability_members.fetch_rows_where( member_data = accountability_members.fetch_rows_where(
@@ -193,11 +194,30 @@ async def turnover():
# TODO: (FUTURE) with high volume, we might want to start the sessions before moving the members. # TODO: (FUTURE) with high volume, we might want to start the sessions before moving the members.
# We could break up the session starting? # We could break up the session starting?
# Move members of the next session over to the session channel # ---------- Start next session ----------
current_slots = [ current_slots = [
aguild.current_slot for aguild in AccountabilityGuild.cache.values() aguild.current_slot for aguild in AccountabilityGuild.cache.values()
if aguild.current_slot is not None if aguild.current_slot is not None
] ]
slotmap = {slot.data.slotid: slot for slot in current_slots if slot.data}
# Reload the slot members in case they cancelled from another shard
member_data = accountability_members.fetch_rows_where(
slotid=list(slotmap.keys())
) if slotmap else []
slot_memberids = {slotid: [] for slotid in slotmap}
for row in member_data:
slot_memberids[row.slotid].append(row.userid)
reload_tasks = (
slot._reload_members(memberids=slot_memberids[slotid])
for slotid, slot in slotmap.items()
)
await asyncio.gather(
*reload_tasks,
return_exceptions=True
)
# Move members of the next session over to the session channel
movement_tasks = ( movement_tasks = (
mem.member.edit( mem.member.edit(
voice_channel=slot.channel, voice_channel=slot.channel,
@@ -335,6 +355,7 @@ async def _accountability_system_resume():
open_room_data = accountability_rooms.fetch_rows_where( open_room_data = accountability_rooms.fetch_rows_where(
closed_at=NULL, closed_at=NULL,
start_at=LEQ(now), start_at=LEQ(now),
guildid=THIS_SHARD,
_extra="ORDER BY start_at ASC" _extra="ORDER BY start_at ASC"
) )
@@ -374,14 +395,15 @@ async def _accountability_system_resume():
None, mow.slotid, mow.userid) None, mow.slotid, mow.userid)
for mow in slot_members[row.slotid] if mow.last_joined_at for mow in slot_members[row.slotid] if mow.last_joined_at
) )
if client.get_guild(row.guildid):
slot = TimeSlot(client.get_guild(row.guildid), row.start_at, data=row).load( slot = TimeSlot(client.get_guild(row.guildid), row.start_at, data=row).load(
memberids=[mow.userid for mow in slot_members[row.slotid]] memberids=[mow.userid for mow in slot_members[row.slotid]]
) )
row.closed_at = now
try: try:
await slot.close() await slot.close()
except discord.HTTPException: except discord.HTTPException:
pass pass
row.closed_at = now
# Load the in-progress room data # Load the in-progress room data
if current_room_data: if current_room_data:
@@ -449,9 +471,11 @@ async def launch_accountability_system(client):
""" """
# Load the AccountabilityGuild cache # Load the AccountabilityGuild cache
guilds = tables.guild_config.fetch_rows_where( guilds = tables.guild_config.fetch_rows_where(
accountability_category=NOTNULL accountability_category=NOTNULL,
guildid=THIS_SHARD
) )
[AccountabilityGuild(guild.guildid) for guild in guilds] # Further filter out any guilds that we aren't in
[AccountabilityGuild(guild.guildid) for guild in guilds if client.get_guild(guild.guildid)]
await _accountability_system_resume() await _accountability_system_resume()
asyncio.create_task(_accountability_loop()) asyncio.create_task(_accountability_loop())

View File

@@ -43,22 +43,18 @@ async def cmd_topcoin(ctx):
# Fetch the leaderboard # Fetch the leaderboard
exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members) exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members)
exclude.update(ctx.client.objects['blacklisted_users']) exclude.update(ctx.client.user_blacklist())
exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id]) exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id])
args = {
'guildid': ctx.guild.id,
'select_columns': ('userid', 'total_coins::INTEGER'),
'_extra': "AND total_coins > 0 ORDER BY total_coins DESC " + ("LIMIT 100" if top100 else "")
}
if exclude: if exclude:
user_data = tables.lions.select_where( args['userid'] = data.NOT(list(exclude))
guildid=ctx.guild.id,
userid=data.NOT(list(exclude)), user_data = tables.members_totals.select_where(**args)
select_columns=('userid', 'coins'),
_extra="AND coins > 0 ORDER BY coins DESC " + ("LIMIT 100" if top100 else "")
)
else:
user_data = tables.lions.select_where(
guildid=ctx.guild.id,
select_columns=('userid', 'coins'),
_extra="AND coins > 0 ORDER BY coins DESC " + ("LIMIT 100" if top100 else "")
)
# Quit early if the leaderboard is empty # Quit early if the leaderboard is empty
if not user_data: if not user_data:

View File

@@ -4,3 +4,4 @@ from . import guild_config
from . import statreset from . import statreset
from . import new_members from . import new_members
from . import reaction_roles from . import reaction_roles
from . import economy

View File

@@ -0,0 +1,3 @@
from ..module import module
from . import set_coins

View File

@@ -0,0 +1,104 @@
import discord
import datetime
from wards import guild_admin
from settings import GuildSettings
from core import Lion
from ..module import module
POSTGRES_INT_MAX = 2147483647
@module.cmd(
"set_coins",
group="Guild Admin",
desc="Set coins on a member."
)
@guild_admin()
async def cmd_set(ctx):
"""
Usage``:
{prefix}set_coins <user mention> <amount>
Description:
Sets the given number of coins on the mentioned user.
If a number greater than 0 is mentioned, will add coins.
If a number less than 0 is mentioned, will remove coins.
Note: LionCoins on a member cannot be negative.
Example:
{prefix}set_coins {ctx.author.mention} 100
{prefix}set_coins {ctx.author.mention} -100
"""
# Extract target and amount
# Handle a slightly more flexible input than stated
splits = ctx.args.split()
digits = [isNumber(split) for split in splits[:2]]
mentions = ctx.msg.mentions
if len(splits) < 2 or not any(digits) or not (all(digits) or mentions):
return await _send_usage(ctx)
if all(digits):
# Both are digits, hopefully one is a member id, and one is an amount.
target, amount = ctx.guild.get_member(int(splits[0])), int(splits[1])
if not target:
amount, target = int(splits[0]), ctx.guild.get_member(int(splits[1]))
if not target:
return await _send_usage(ctx)
elif digits[0]:
amount, target = int(splits[0]), mentions[0]
elif digits[1]:
target, amount = mentions[0], int(splits[1])
# Fetch the associated lion
target_lion = Lion.fetch(ctx.guild.id, target.id)
# Check sanity conditions
if target == ctx.client.user:
return await ctx.embed_reply("Thanks, but Ari looks after all my needs!")
if target.bot:
return await ctx.embed_reply("We are still waiting for {} to open an account.".format(target.mention))
# Finally, send the amount and the ack message
# Postgres `coins` column is `integer`, sanity check postgres int limits - which are smalled than python int range
target_coins_to_set = target_lion.coins + amount
if target_coins_to_set >= 0 and target_coins_to_set <= POSTGRES_INT_MAX:
target_lion.addCoins(amount)
elif target_coins_to_set < 0:
target_coins_to_set = -target_lion.coins # Coins cannot go -ve, cap to 0
target_lion.addCoins(target_coins_to_set)
target_coins_to_set = 0
else:
return await ctx.embed_reply("Member coins cannot be more than {}".format(POSTGRES_INT_MAX))
embed = discord.Embed(
title="Funds Set",
description="You have set LionCoins on {} to **{}**!".format(target.mention,target_coins_to_set),
colour=discord.Colour.orange(),
timestamp=datetime.datetime.utcnow()
).set_footer(text=str(ctx.author), icon_url=ctx.author.avatar_url)
await ctx.reply(embed=embed, reference=ctx.msg)
GuildSettings(ctx.guild.id).event_log.log(
"{} set {}'s LionCoins to`{}`.".format(
ctx.author.mention,
target.mention,
target_coins_to_set
),
title="Funds Set"
)
def isNumber(var):
try:
return isinstance(int(var), int)
except:
return False
async def _send_usage(ctx):
return await ctx.error_reply(
"**Usage:** `{prefix}set_coins <mention> <amount>`\n"
"**Example:**\n"
" {prefix}set_coins {ctx.author.mention} 100\n"
" {prefix}set_coins {ctx.author.mention} -100".format(
prefix=ctx.best_prefix,
ctx=ctx
)
)

View File

@@ -1,3 +1,4 @@
import difflib
import discord import discord
from cmdClient.lib import SafeCancellation from cmdClient.lib import SafeCancellation
@@ -121,9 +122,15 @@ async def cmd_config(ctx, flags):
name = parts[0] name = parts[0]
setting = setting_displaynames.get(name.lower(), None) setting = setting_displaynames.get(name.lower(), None)
if setting is None: if setting is None:
matches = difflib.get_close_matches(name, setting_displaynames.keys(), n=2)
match = "`{}`".format('` or `'.join(matches)) if matches else None
return await ctx.error_reply( return await ctx.error_reply(
"Server setting `{}` doesn't exist! Use `{}config` to see all server settings".format( "Couldn't find a setting called `{}`!\n"
name, ctx.best_prefix "{}"
"Use `{}config info` to see all the server settings.".format(
name,
"Maybe you meant {}?\n".format(match) if match else "",
ctx.best_prefix
) )
) )

View File

@@ -485,7 +485,7 @@ async def cmd_reactionroles(ctx, flags):
await ctx.error_reply( await ctx.error_reply(
"The provided channel no longer exists!" "The provided channel no longer exists!"
) )
elif channel.type != discord.ChannelType.text: elif not isinstance(channel, discord.TextChannel):
await ctx.error_reply( await ctx.error_reply(
"The provided channel is not a text channel!" "The provided channel is not a text channel!"
) )
@@ -821,8 +821,8 @@ async def cmd_reactionroles(ctx, flags):
setting = await setting_class.parse(target.messageid, ctx, flags[flag]) setting = await setting_class.parse(target.messageid, ctx, flags[flag])
except UserInputError as e: except UserInputError as e:
return await ctx.error_reply( return await ctx.error_reply(
title="Couldn't save settings!", "{} {}\nNo settings were modified.".format(cross, e.msg),
description="{} {}\nNo settings were modified.".format(cross, e.msg) title="Couldn't save settings!"
) )
else: else:
update_lines.append( update_lines.append(
@@ -861,8 +861,8 @@ async def cmd_reactionroles(ctx, flags):
setting = await setting_class.parse(reaction.reactionid, ctx, flags[flag]) setting = await setting_class.parse(reaction.reactionid, ctx, flags[flag])
except UserInputError as e: except UserInputError as e:
return await ctx.error_reply( return await ctx.error_reply(
"{} {}\nNo reaction roles were modified.".format(cross, e.msg),
title="Couldn't save reaction role settings!", title="Couldn't save reaction role settings!",
description="{} {}\nNo reaction roles were modified.".format(cross, e.msg)
) )
else: else:
update_lines.append( update_lines.append(

View File

@@ -199,6 +199,7 @@ class price(setting_types.Integer, ReactionSetting):
) )
accepts = "An integer number of coins. Use `0` to make the role free, or `None` to use the message default." accepts = "An integer number of coins. Use `0` to make the role free, or `None` to use the message default."
_max = 2 ** 20
@property @property
def default(self): def default(self):

View File

@@ -12,6 +12,7 @@ from discord import PartialEmoji
from meta import client from meta import client
from core import Lion from core import Lion
from data import Row from data import Row
from data.conditions import THIS_SHARD
from utils.lib import utc_now from utils.lib import utc_now
from settings import GuildSettings from settings import GuildSettings
@@ -272,7 +273,7 @@ class ReactionRoleMessage:
# Fetch the number of applicable roles the user has # Fetch the number of applicable roles the user has
roleids = set(reaction.data.roleid for reaction in self.reactions) roleids = set(reaction.data.roleid for reaction in self.reactions)
member_roleids = set(role.id for role in member.roles) member_roleids = set(role.id for role in member.roles)
if len(roleids.intersection(member_roleids)) > maximum: if len(roleids.intersection(member_roleids)) >= maximum:
# Notify the user # Notify the user
embed = discord.Embed( embed = discord.Embed(
title="Maximum group roles reached!", title="Maximum group roles reached!",
@@ -584,5 +585,5 @@ def load_reaction_roles(client):
""" """
Load the ReactionRoleMessages. Load the ReactionRoleMessages.
""" """
rows = reaction_role_messages.fetch_rows_where() rows = reaction_role_messages.fetch_rows_where(guildid=THIS_SHARD)
ReactionRoleMessage._messages = {row.messageid: ReactionRoleMessage(row.messageid) for row in rows} ReactionRoleMessage._messages = {row.messageid: ReactionRoleMessage(row.messageid) for row in rows}

View File

@@ -6,6 +6,7 @@ import datetime
import discord import discord
from meta import client from meta import client
from data.conditions import THIS_SHARD
from settings import GuildSettings from settings import GuildSettings
from utils.lib import FieldEnum, strfdelta, utc_now from utils.lib import FieldEnum, strfdelta, utc_now
@@ -283,7 +284,8 @@ class Ticket:
# Get all expiring tickets # Get all expiring tickets
expiring_rows = data.tickets.select_where( expiring_rows = data.tickets.select_where(
ticket_state=TicketState.EXPIRING ticket_state=TicketState.EXPIRING,
guildid=THIS_SHARD
) )
# Create new expiry tasks # Create new expiry tasks

View File

@@ -56,10 +56,10 @@ class video_channels(settings.ChannelList, settings.ListData, settings.Setting):
if any(channel.members for channel in guild.voice_channels) if any(channel.members for channel in guild.voice_channels)
] ]
if active_guildids: if active_guildids:
cache = {guildid: [] for guildid in active_guildids}
rows = cls._table_interface.select_where( rows = cls._table_interface.select_where(
guildid=active_guildids guildid=active_guildids
) )
cache = defaultdict(list)
for row in rows: for row in rows:
cache[row['guildid']].append(row['channelid']) cache[row['guildid']].append(row['channelid'])
cls._cache.update(cache) cls._cache.update(cache)

View File

@@ -3,6 +3,7 @@ import asyncio
import datetime import datetime
import discord import discord
from meta import sharding
from utils.lib import parse_dur, parse_ranges, multiselect_regex from utils.lib import parse_dur, parse_ranges, multiselect_regex
from .module import module from .module import module
@@ -55,7 +56,7 @@ async def cmd_remindme(ctx, flags):
if not rows: if not rows:
return await ctx.reply("You have no reminders to remove!") return await ctx.reply("You have no reminders to remove!")
live = Reminder.fetch(*(row.reminderid for row in rows)) live = [Reminder(row.reminderid) for row in rows]
if not ctx.args: if not ctx.args:
lines = [] lines = []
@@ -209,6 +210,7 @@ async def cmd_remindme(ctx, flags):
) )
# Schedule reminder # Schedule reminder
if sharding.shard_number == 0:
reminder.schedule() reminder.schedule()
# Ack # Ack
@@ -231,7 +233,7 @@ async def cmd_remindme(ctx, flags):
if not rows: if not rows:
return await ctx.reply("You have no reminders!") return await ctx.reply("You have no reminders!")
live = Reminder.fetch(*(row.reminderid for row in rows)) live = [Reminder(row.reminderid) for row in rows]
lines = [] lines = []
num_field = len(str(len(live) - 1)) num_field = len(str(len(live) - 1))

View File

@@ -1,8 +1,9 @@
import asyncio import asyncio
import datetime import datetime
import logging
import discord import discord
from meta import client from meta import client, sharding
from utils.lib import strfdur from utils.lib import strfdur
from .data import reminders from .data import reminders
@@ -46,7 +47,10 @@ class Reminder:
cls._live_reminders[reminderid].cancel() cls._live_reminders[reminderid].cancel()
# Remove from data # Remove from data
reminders.delete_where(reminderid=reminderids) if reminderids:
return reminders.delete_where(reminderid=reminderids)
else:
return []
@property @property
def data(self): def data(self):
@@ -134,10 +138,16 @@ class Reminder:
""" """
Execute the reminder. Execute the reminder.
""" """
if self.data.userid in client.objects['blacklisted_users']: if not self.data:
# Reminder deleted elsewhere
return
if self.data.userid in client.user_blacklist():
self.delete(self.reminderid) self.delete(self.reminderid)
return return
userid = self.data.userid
# Build the message embed # Build the message embed
embed = discord.Embed( embed = discord.Embed(
title="You asked me to remind you!", title="You asked me to remind you!",
@@ -155,8 +165,26 @@ class Reminder:
) )
) )
# Update the reminder data, and reschedule if required
if self.data.interval:
next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval)
rows = reminders.update_where(
{'remind_at': next_time},
reminderid=self.reminderid
)
self.schedule()
else:
rows = self.delete(self.reminderid)
if not rows:
# Reminder deleted elsewhere
return
# Send the message, if possible # Send the message, if possible
user = self.user if not (user := client.get_user(userid)):
try:
user = await client.fetch_user(userid)
except discord.HTTPException:
pass
if user: if user:
try: try:
await user.send(embed=embed) await user.send(embed=embed)
@@ -164,17 +192,32 @@ class Reminder:
# Nothing we can really do here. Maybe tell the user about their reminder next time? # Nothing we can really do here. Maybe tell the user about their reminder next time?
pass pass
# Update the reminder data, and reschedule if required
if self.data.interval: async def reminder_poll(client):
next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval) """
reminders.update_where({'remind_at': next_time}, reminderid=self.reminderid) One client/shard must continually poll for new or deleted reminders.
self.schedule() """
else: # TODO: Clean this up with database signals or IPC
self.delete(self.reminderid) while True:
await asyncio.sleep(60)
client.log(
"Running new reminder poll.",
context="REMINDERS",
level=logging.DEBUG
)
rids = {row.reminderid for row in reminders.fetch_rows_where()}
to_delete = (rid for rid in Reminder._live_reminders if rid not in rids)
Reminder.delete(*to_delete)
[Reminder(rid).schedule() for rid in rids if rid not in Reminder._live_reminders]
@module.launch_task @module.launch_task
async def schedule_reminders(client): async def schedule_reminders(client):
if sharding.shard_number == 0:
rows = reminders.fetch_rows_where() rows = reminders.fetch_rows_where()
for row in rows: for row in rows:
Reminder(row.reminderid).schedule() Reminder(row.reminderid).schedule()
@@ -182,3 +225,5 @@ async def schedule_reminders(client):
"Scheduled {} reminders.".format(len(rows)), "Scheduled {} reminders.".format(len(rows)),
context="LAUNCH_REMINDERS" context="LAUNCH_REMINDERS"
) )
if sharding.sharded:
asyncio.create_task(reminder_poll(client))

View File

@@ -54,9 +54,13 @@ async def cmd_rent(ctx):
# Extract members to remove # Extract members to remove
current_memberids = set(room.memberids) current_memberids = set(room.memberids)
if ctx.author in ctx.msg.mentions:
return await ctx.error_reply(
"You can't remove yourself from your own room!"
)
to_remove = ( to_remove = (
member for member in ctx.msg.mentions member for member in ctx.msg.mentions
if member.id in current_memberids if member.id in current_memberids and member.id != ctx.author.id
) )
to_remove = list(set(to_remove)) # Remove duplicates to_remove = list(set(to_remove)) # Remove duplicates
@@ -86,7 +90,7 @@ async def cmd_rent(ctx):
current_memberids = set(room.memberids) current_memberids = set(room.memberids)
to_add = ( to_add = (
member for member in ctx.msg.mentions member for member in ctx.msg.mentions
if member.id not in current_memberids and member.id != ctx.author if member.id not in current_memberids and member.id != ctx.author.id
) )
to_add = list(set(to_add)) # Remove duplicates to_add = list(set(to_add)) # Remove duplicates

View File

@@ -5,6 +5,7 @@ import datetime
from cmdClient.lib import SafeCancellation from cmdClient.lib import SafeCancellation
from meta import client from meta import client
from data.conditions import THIS_SHARD
from settings import GuildSettings from settings import GuildSettings
from .data import rented, rented_members from .data import rented, rented_members
@@ -187,14 +188,14 @@ class Room:
except discord.HTTPException: except discord.HTTPException:
pass pass
# Delete the room from data (cascades to member deletion)
self.delete()
guild_settings.event_log.log( guild_settings.event_log.log(
title="Private study room expired!", title="Private study room expired!",
description="<@{}>'s private study room expired.".format(self.data.ownerid) description="<@{}>'s private study room expired.".format(self.data.ownerid)
) )
# Delete the room from data (cascades to member deletion)
self.delete()
async def add_members(self, *members): async def add_members(self, *members):
guild_settings = GuildSettings(self.data.guildid) guild_settings = GuildSettings(self.data.guildid)
@@ -276,7 +277,7 @@ class Room:
@module.launch_task @module.launch_task
async def load_rented_rooms(client): async def load_rented_rooms(client):
rows = rented.fetch_rows_where() rows = rented.fetch_rows_where(guildid=THIS_SHARD)
for row in rows: for row in rows:
Room(row.channelid).schedule() Room(row.channelid).schedule()
client.log( client.log(

View File

@@ -0,0 +1,7 @@
from .module import module
from . import data
from . import profile
from . import setprofile
from . import top_cmd
from . import goals

39
bot/modules/stats/data.py Normal file
View File

@@ -0,0 +1,39 @@
from cachetools import TTLCache
from data import Table, RowTable
profile_tags = Table('member_profile_tags', attach_as='profile_tags')
@profile_tags.save_query
def get_tags_for(guildid, userid):
rows = profile_tags.select_where(
guildid=guildid, userid=userid,
_extra="ORDER BY tagid ASC"
)
return [row['tag'] for row in rows]
weekly_goals = RowTable(
'member_weekly_goals',
('guildid', 'userid', 'weekid', 'study_goal', 'task_goal'),
('guildid', 'userid', 'weekid'),
cache=TTLCache(5000, 60 * 60 * 24),
attach_as='weekly_goals'
)
# NOTE: Not using a RowTable here since these will almost always be mass-selected
weekly_tasks = Table('member_weekly_goal_tasks')
monthly_goals = RowTable(
'member_monthly_goals',
('guildid', 'userid', 'monthid', 'study_goal', 'task_goal'),
('guildid', 'userid', 'monthid'),
cache=TTLCache(5000, 60 * 60 * 24),
attach_as='monthly_goals'
)
monthly_tasks = Table('member_monthly_goal_tasks')

332
bot/modules/stats/goals.py Normal file
View File

@@ -0,0 +1,332 @@
"""
Weekly and Monthly goal display and edit interface.
"""
from enum import Enum
import discord
from cmdClient.checks import in_guild
from cmdClient.lib import SafeCancellation
from utils.lib import parse_ranges
from .module import module
from .data import weekly_goals, weekly_tasks, monthly_goals, monthly_tasks
MAX_LENGTH = 200
MAX_TASKS = 10
class GoalType(Enum):
WEEKLY = 0
MONTHLY = 1
def index_range_parser(userstr, max):
try:
indexes = parse_ranges(userstr)
except SafeCancellation:
raise SafeCancellation(
"Couldn't parse the provided task ids! "
"Please list the task numbers or ranges separated by a comma, e.g. `0, 2-4`."
) from None
return [index for index in indexes if index <= max]
@module.cmd(
"weeklygoals",
group="Statistics",
desc="Set your weekly goals and view your progress!",
aliases=('weeklygoal',),
flags=('study=', 'tasks=')
)
@in_guild()
async def cmd_weeklygoals(ctx, flags):
"""
Usage``:
{prefix}weeklygoals [--study <hours>] [--tasks <number>]
{prefix}weeklygoals add <task>
{prefix}weeklygoals edit <taskid> <new task>
{prefix}weeklygoals check <taskids>
{prefix}weeklygoals remove <taskids>
Description:
Set yourself up to `10` goals for this week and keep yourself accountable!
Use `add/edit/check/remove` to edit your goals, similarly to `{prefix}todo`.
You can also add multiple tasks at once by writing them on multiple lines.
You can also track your progress towards a number of hours studied with `--study`, \
and aim for a number of tasks completed with `--tasks`.
Run the command with no arguments or check your profile to see your progress!
Examples``:
{prefix}weeklygoals add Read chapters 1 to 10.
{prefix}weeklygoals check 1
{prefix}weeklygoals --study 48h --tasks 60
"""
await goals_command(ctx, flags, GoalType.WEEKLY)
@module.cmd(
"monthlygoals",
group="Statistics",
desc="Set your monthly goals and view your progress!",
aliases=('monthlygoal',),
flags=('study=', 'tasks=')
)
@in_guild()
async def cmd_monthlygoals(ctx, flags):
"""
Usage``:
{prefix}monthlygoals [--study <hours>] [--tasks <number>]
{prefix}monthlygoals add <task>
{prefix}monthlygoals edit <taskid> <new task>
{prefix}monthlygoals check <taskids>
{prefix}monthlygoals uncheck <taskids>
{prefix}monthlygoals remove <taskids>
Description:
Set yourself up to `10` goals for this month and keep yourself accountable!
Use `add/edit/check/remove` to edit your goals, similarly to `{prefix}todo`.
You can also add multiple tasks at once by writing them on multiple lines.
You can also track your progress towards a number of hours studied with `--study`, \
and aim for a number of tasks completed with `--tasks`.
Run the command with no arguments or check your profile to see your progress!
Examples``:
{prefix}monthlygoals add Read chapters 1 to 10.
{prefix}monthlygoals check 1
{prefix}monthlygoals --study 180h --tasks 60
"""
await goals_command(ctx, flags, GoalType.MONTHLY)
async def goals_command(ctx, flags, goal_type):
prefix = ctx.best_prefix
if goal_type == GoalType.WEEKLY:
name = 'week'
goal_table = weekly_goals
task_table = weekly_tasks
rowkey = 'weekid'
rowid = ctx.alion.week_timestamp
tasklist = task_table.select_where(
guildid=ctx.guild.id,
userid=ctx.author.id,
weekid=rowid,
_extra="ORDER BY taskid ASC"
)
max_time = 7 * 16
else:
name = 'month'
goal_table = monthly_goals
task_table = monthly_tasks
rowid = ctx.alion.month_timestamp
rowkey = 'monthid'
tasklist = task_table.select_where(
guildid=ctx.guild.id,
userid=ctx.author.id,
monthid=rowid,
_extra="ORDER BY taskid ASC"
)
max_time = 31 * 16
# We ensured the `lion` existed with `ctx.alion` above
# This also ensures a new tasklist can reference the period member goal key
# TODO: Should creation copy the previous existing week?
goal_row = goal_table.fetch_or_create((ctx.guild.id, ctx.author.id, rowid))
if flags['study']:
# Set study hour goal
time = flags['study'].lower().strip('h ')
if not time or not time.isdigit():
return await ctx.error_reply(
f"Please provide your {name}ly study goal in hours!\n"
f"For example, `{prefix}{ctx.alias} --study 48h`"
)
hours = int(time)
if hours > max_time:
return await ctx.error_reply(
"You can't set your goal this high! Please rest and keep a healthy lifestyle."
)
goal_row.study_goal = hours
if flags['tasks']:
# Set tasks completed goal
count = flags['tasks']
if not count or not count.isdigit():
return await ctx.error_reply(
f"Please provide the number of tasks you want to complete this {name}!\n"
f"For example, `{prefix}{ctx.alias} --tasks 300`"
)
if int(count) > 2048:
return await ctx.error_reply(
"Your task goal is too high!"
)
goal_row.task_goal = int(count)
if ctx.args:
# If there are arguments, assume task/goal management
# Extract the command if it exists, assume add operation if it doesn't
splits = ctx.args.split(maxsplit=1)
cmd = splits[0].lower().strip()
args = splits[1].strip() if len(splits) > 1 else ''
if cmd in ('check', 'done', 'complete'):
if not args:
# Show subcommand usage
return await ctx.error_reply(
f"**Usage:**`{prefix}{ctx.alias} check <taskids>`\n"
f"**Example:**`{prefix}{ctx.alias} check 0, 2-4`"
)
if (indexes := index_range_parser(args, len(tasklist) - 1)):
# Check the given indexes
# If there are no valid indexes given, just do nothing and fall out to showing the goals
task_table.update_where(
{'completed': True},
taskid=[tasklist[index]['taskid'] for index in indexes]
)
elif cmd in ('uncheck', 'undone', 'uncomplete'):
if not args:
# Show subcommand usage
return await ctx.error_reply(
f"**Usage:**`{prefix}{ctx.alias} uncheck <taskids>`\n"
f"**Example:**`{prefix}{ctx.alias} uncheck 0, 2-4`"
)
if (indexes := index_range_parser(args, len(tasklist) - 1)):
# Check the given indexes
# If there are no valid indexes given, just do nothing and fall out to showing the goals
task_table.update_where(
{'completed': False},
taskid=[tasklist[index]['taskid'] for index in indexes]
)
elif cmd in ('remove', 'delete', '-', 'rm'):
if not args:
# Show subcommand usage
return await ctx.error_reply(
f"**Usage:**`{prefix}{ctx.alias} remove <taskids>`\n"
f"**Example:**`{prefix}{ctx.alias} remove 0, 2-4`"
)
if (indexes := index_range_parser(args, len(tasklist) - 1)):
# Delete the given indexes
# If there are no valid indexes given, just do nothing and fall out to showing the goals
task_table.delete_where(
taskid=[tasklist[index]['taskid'] for index in indexes]
)
elif cmd == 'edit':
if not args or len(splits := args.split(maxsplit=1)) < 2 or not splits[0].isdigit():
# Show subcommand usage
return await ctx.error_reply(
f"**Usage:**`{prefix}{ctx.alias} edit <taskid> <edited task>`\n"
f"**Example:**`{prefix}{ctx.alias} edit 2 Fix the scond task`"
)
index = int(splits[0])
new_content = splits[1].strip()
if index >= len(tasklist):
return await ctx.error_reply(
f"Task `{index}` doesn't exist to edit!"
)
if len(new_content) > MAX_LENGTH:
return await ctx.error_reply(
f"Please keep your goals under `{MAX_LENGTH}` characters long."
)
# Passed all checks, edit task
task_table.update_where(
{'content': new_content},
taskid=tasklist[index]['taskid']
)
else:
# Extract the tasks to add
if cmd in ('add', '+'):
if not args:
# Show subcommand usage
return await ctx.error_reply(
f"**Usage:**`{prefix}{ctx.alias} [add] <new task>`\n"
f"**Example:**`{prefix}{ctx.alias} add Read the Studylion help pages.`"
)
else:
args = ctx.args
tasks = args.splitlines()
# Check count
if len(tasklist) + len(tasks) > MAX_TASKS:
return await ctx.error_reply(
f"You can have at most **{MAX_TASKS}** {name}ly goals!"
)
# Check length
if any(len(task) > MAX_LENGTH for task in tasks):
return await ctx.error_reply(
f"Please keep your goals under `{MAX_LENGTH}` characters long."
)
# We passed the checks, add the tasks
to_insert = [
(ctx.guild.id, ctx.author.id, rowid, task)
for task in tasks
]
task_table.insert_many(
*to_insert,
insert_keys=('guildid', 'userid', rowkey, 'content')
)
elif not any((goal_row.study_goal, goal_row.task_goal, tasklist)):
# The user hasn't set any goals for this time period
# Prompt them with information about how to set a goal
embed = discord.Embed(
colour=discord.Colour.orange(),
title=f"**You haven't set any goals for this {name} yet! Try the following:**\n"
)
embed.add_field(
name="Aim for a number of study hours with",
value=f"`{prefix}{ctx.alias} --study 48h`"
)
embed.add_field(
name="Aim for a number of tasks completed with",
value=f"`{prefix}{ctx.alias} --tasks 300`",
inline=False
)
embed.add_field(
name=f"Set up to 10 custom goals for the {name}!",
value=(
f"`{prefix}{ctx.alias} add Write a 200 page thesis.`\n"
f"`{prefix}{ctx.alias} edit 1 Write 2 pages of the 200 page thesis.`\n"
f"`{prefix}{ctx.alias} done 0, 1, 3-4`\n"
f"`{prefix}{ctx.alias} delete 2-4`"
),
inline=False
)
return await ctx.reply(embed=embed)
# Show the goals
if goal_type == GoalType.WEEKLY:
await display_weekly_goals_for(ctx)
else:
await display_monthly_goals_for(ctx)
async def display_weekly_goals_for(ctx):
"""
Display the user's weekly goal summary and progress towards them
TODO: Currently a stub, since the system is overidden by the GUI plugin
"""
# Collect data
lion = ctx.alion
rowid = lion.week_timestamp
goals = weekly_goals.fetch_or_create((ctx.guild.id, ctx.author.id, rowid))
tasklist = weekly_tasks.select_where(
guildid=ctx.guild.id,
userid=ctx.author.id,
weekid=rowid
)
...
async def display_monthly_goals_for(ctx):
...

View File

@@ -0,0 +1,4 @@
from LionModule import LionModule
module = LionModule("Statistics")

View File

@@ -0,0 +1,266 @@
from datetime import datetime, timedelta
import discord
from cmdClient.checks import in_guild
from utils.lib import prop_tabulate, utc_now
from data import tables
from data.conditions import LEQ
from core import Lion
from modules.study.tracking.data import session_history
from .module import module
@module.cmd(
"stats",
group="Statistics",
desc="View your personal server study statistics!",
aliases=('profile',),
allow_before_ready=True
)
@in_guild()
async def cmd_stats(ctx):
"""
Usage``:
{prefix}stats
{prefix}stats <user mention>
Description:
View the study statistics for yourself or the mentioned user.
"""
# Identify the target
if ctx.args:
if not ctx.msg.mentions:
return await ctx.error_reply("Please mention a user to view their statistics!")
target = ctx.msg.mentions[0]
else:
target = ctx.author
# System sync
Lion.sync()
# Fetch the required data
lion = Lion.fetch(ctx.guild.id, target.id)
history = session_history.select_where(
guildid=ctx.guild.id,
userid=target.id,
select_columns=(
"start_time",
"(start_time + duration * interval '1 second') AS end_time"
),
_extra="ORDER BY start_time DESC"
)
# Current economy balance (accounting for current session)
coins = lion.coins
season_time = lion.time
workout_total = lion.data.workout_count
# Leaderboard ranks
exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members)
exclude.update(ctx.client.user_blacklist())
exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id])
if target.id in exclude:
time_rank = None
coin_rank = None
else:
time_rank, coin_rank = tables.lions.queries.get_member_rank(ctx.guild.id, target.id, list(exclude or [0]))
# Study time
# First get the all/month/week/day timestamps
day_start = lion.day_start
period_timestamps = (
datetime(1970, 1, 1),
day_start.replace(day=1),
day_start - timedelta(days=day_start.weekday()),
day_start
)
study_times = [0, 0, 0, 0]
for i, timestamp in enumerate(period_timestamps):
study_time = tables.session_history.queries.study_time_since(ctx.guild.id, target.id, timestamp)
if not study_time:
# So we don't make unecessary database calls
break
study_times[i] = study_time
# Streak statistics
streak = 0
current_streak = None
max_streak = 0
day_attended = True if 'sessions' in ctx.client.objects and lion.session else None
date = day_start
daydiff = timedelta(days=1)
periods = [(row['start_time'], row['end_time']) for row in history]
i = 0
while i < len(periods):
row = periods[i]
i += 1
if row[1] > date:
# They attended this day
day_attended = True
continue
elif day_attended is None:
# Didn't attend today, but don't break streak
day_attended = False
date -= daydiff
i -= 1
continue
elif not day_attended:
# Didn't attend the day, streak broken
date -= daydiff
i -= 1
pass
else:
# Attended the day
streak += 1
# Move window to the previous day and try the row again
day_attended = False
prev_date = date
date -= daydiff
i -= 1
# Special case, when the last session started in the previous day
# Then the day is already attended
if i > 1 and date < periods[i-2][0] <= prev_date:
day_attended = True
continue
max_streak = max(max_streak, streak)
if current_streak is None:
current_streak = streak
streak = 0
# Handle loop exit state, i.e. the last streak
if day_attended:
streak += 1
max_streak = max(max_streak, streak)
if current_streak is None:
current_streak = streak
# Accountability stats
accountability = tables.accountability_member_info.select_where(
userid=target.id,
start_at=LEQ(utc_now()),
select_columns=("*", "(duration > 0 OR last_joined_at IS NOT NULL) AS attended"),
_extra="ORDER BY start_at DESC"
)
if len(accountability):
acc_duration = sum(row['duration'] for row in accountability)
acc_attended = sum(row['attended'] for row in accountability)
acc_total = len(accountability)
acc_rate = (acc_attended * 100) / acc_total
else:
acc_duration = 0
acc_rate = 0
# Study League
guild_badges = tables.study_badges.fetch_rows_where(guildid=ctx.guild.id)
if lion.data.last_study_badgeid:
current_badge = tables.study_badges.fetch(lion.data.last_study_badgeid)
else:
current_badge = None
next_badge = min(
(badge for badge in guild_badges
if badge.required_time > (current_badge.required_time if current_badge else 0)),
key=lambda badge: badge.required_time,
default=None
)
# We have all the data
# Now start building the embed
embed = discord.Embed(
colour=discord.Colour.orange(),
title="Study Profile for {}".format(str(target))
)
embed.set_thumbnail(url=target.avatar_url)
# Add studying since if they have studied
if history:
embed.set_footer(text="Studying Since")
embed.timestamp = history[-1]['start_time']
# Set the description based on season time and server rank
if season_time:
time_str = "**{}:{:02}**".format(
season_time // 3600,
(season_time // 60) % 60
)
if time_rank is None:
rank_str = None
elif time_rank == 1:
rank_str = "1st"
elif time_rank == 2:
rank_str = "2nd"
elif time_rank == 3:
rank_str = "3rd"
else:
rank_str = "{}th".format(time_rank)
embed.description = "{} has studied for **{}**{}{}".format(
target.mention,
time_str,
" this season" if study_times[0] - season_time > 60 else "",
", and is ranked **{}** in the server!".format(rank_str) if rank_str else "."
)
else:
embed.description = "{} hasn't studied in this server yet!".format(target.mention)
# Build the stats table
stats = {}
stats['Coins Earned'] = "**{}** LC".format(
coins,
# "Rank `{}`".format(coin_rank) if coins and coin_rank else "Unranked"
)
if workout_total:
stats['Workouts'] = "**{}** sessions".format(workout_total)
if acc_duration:
stats['Accountability'] = "**{}** hours (`{:.0f}%` attended)".format(
acc_duration // 3600,
acc_rate
)
stats['Study Streak'] = "**{}** days{}".format(
current_streak,
" (longest **{}** days)".format(max_streak) if max_streak else ''
)
stats_table = prop_tabulate(*zip(*stats.items()))
# Build the time table
time_table = prop_tabulate(
('Daily', 'Weekly', 'Monthly', 'All Time'),
["{:02}:{:02}".format(t // 3600, (t // 60) % 60) for t in reversed(study_times)]
)
# Populate the embed
embed.add_field(name="Study Time", value=time_table)
embed.add_field(name="Statistics", value=stats_table)
# Add the study league field
if current_badge or next_badge:
current_str = (
"You are currently in <@&{}>!".format(current_badge.roleid) if current_badge else "No league yet!"
)
if next_badge:
needed = max(next_badge.required_time - season_time, 0)
next_str = "Study for **{:02}:{:02}** more to achieve <@&{}>.".format(
needed // 3600,
(needed // 60) % 60,
next_badge.roleid
)
else:
next_str = "You have reached the highest league! Congratulations!"
embed.add_field(
name="Study League",
value="{}\n{}".format(current_str, next_str),
inline=False
)
await ctx.reply(embed=embed)

View File

@@ -0,0 +1,225 @@
"""
Provides a command to update a member's profile badges.
"""
import string
import discord
from cmdClient.lib import SafeCancellation
from cmdClient.checks import in_guild
from wards import guild_moderator
from .data import profile_tags
from .module import module
MAX_TAGS = 10
MAX_LENGTH = 30
@module.cmd(
"setprofile",
group="Personal Settings",
desc="Set or update your study profile tags.",
aliases=('editprofile', 'mytags'),
flags=('clear', 'for')
)
@in_guild()
async def cmd_setprofile(ctx, flags):
"""
Usage``:
{prefix}setprofile <tag>, <tag>, <tag>, ...
{prefix}setprofile <id> <new tag>
{prefix}setprofile --clear [--for @user]
Description:
Set or update the tags appearing in your study server profile.
Moderators can clear a user's tags with `--clear --for @user`.
Examples``:
{prefix}setprofile Mathematics, Bioloyg, Medicine, Undergraduate, Europe
{prefix}setprofile 2 Biology
{prefix}setprofile --clear
"""
if flags['clear']:
if flags['for']:
# Moderator-clearing a user's tags
# First check moderator permissions
if not await guild_moderator.run(ctx):
return await ctx.error_reply(
"You need to be a server moderator to use this!"
)
# Check input and extract users to clear for
if not (users := ctx.msg.mentions):
# Show moderator usage
return await ctx.error_reply(
f"**Usage:** `{ctx.best_prefix}setprofile --clear --for @user`\n"
f"**Example:** {ctx.best_prefix}setprofile --clear --for {ctx.author.mention}"
)
# Clear the tags
profile_tags.delete_where(
guildid=ctx.guild.id,
userid=[user.id for user in users]
)
# Ack the moderator
await ctx.embed_reply(
"Profile tags cleared!"
)
else:
# The author wants to clear their own tags
# First delete the tags, save the rows for reporting
rows = profile_tags.delete_where(
guildid=ctx.guild.id,
userid=ctx.author.id
)
# Ack the user
if not rows:
await ctx.embed_reply(
"You don't have any profile tags to clear!"
)
else:
embed = discord.Embed(
colour=discord.Colour.green(),
description="Successfully cleared your profile!"
)
embed.add_field(
name="Removed tags",
value='\n'.join(row['tag'].upper() for row in rows)
)
await ctx.reply(embed=embed)
elif ctx.args:
if len(splits := ctx.args.split(maxsplit=1)) > 1 and splits[0].isdigit():
# Assume we are editing the provided id
tagid = int(splits[0])
if tagid > MAX_TAGS:
return await ctx.error_reply(
f"Sorry, you can have a maximum of `{MAX_TAGS}` tags!"
)
# Retrieve the user's current taglist
rows = profile_tags.select_where(
guildid=ctx.guild.id,
userid=ctx.author.id,
_extra="ORDER BY tagid ASC"
)
# Parse and validate provided new content
content = splits[1].strip().upper()
validate_tag(content)
if tagid > len(rows):
# Trying to edit a tag that doesn't exist yet
# Just create it instead
profile_tags.insert(
guildid=ctx.guild.id,
userid=ctx.author.id,
tag=content
)
# Ack user
await ctx.reply(
embed=discord.Embed(title="Tag created!", colour=discord.Colour.green())
)
else:
# Get the row id to update
to_edit = rows[tagid - 1]['tagid']
# Update the tag
profile_tags.update_where(
{'tag': content},
tagid=to_edit
)
# Ack user
embed = discord.Embed(
colour=discord.Colour.green(),
title="Tag updated!"
)
await ctx.reply(embed=embed)
else:
# Assume the arguments are a comma separated list of badges
# Parse and validate
to_add = [split.strip().upper() for line in ctx.args.splitlines() for split in line.split(',')]
to_add = [split.replace('<3', '❤️') for split in to_add if split]
if not to_add:
return await ctx.error_reply("No valid tags given, nothing to do!")
validate_tag(*to_add)
if len(to_add) > MAX_TAGS:
return await ctx.error_reply(f"You can have a maximum of {MAX_TAGS} tags!")
# Remove the existing badges
deleted_rows = profile_tags.delete_where(
guildid=ctx.guild.id,
userid=ctx.author.id
)
# Insert the new tags
profile_tags.insert_many(
*((ctx.guild.id, ctx.author.id, tag) for tag in to_add),
insert_keys=('guildid', 'userid', 'tag')
)
# Ack with user
embed = discord.Embed(
colour=discord.Colour.green(),
title="Profile tags updated!"
)
embed.add_field(
name="New tags",
value='\n'.join(to_add)
)
if deleted_rows:
embed.add_field(
name="Replaced tags",
value='\n'.join(row['tag'].upper() for row in deleted_rows),
inline=False
)
if len(to_add) == 1:
embed.set_footer(
text=f"TIP: Add multiple tags with {ctx.best_prefix}setprofile tag1, tag2, ..."
)
await ctx.reply(embed=embed)
else:
# No input was provided
# Show usage and exit
embed = discord.Embed(
colour=discord.Colour.red(),
description=(
"Edit your study profile "
"tags so other people can see what you do!"
)
)
embed.add_field(
name="Usage",
value=(
f"`{ctx.best_prefix}setprofile <tag>, <tag>, <tag>, ...`\n"
f"`{ctx.best_prefix}setprofile <id> <new tag>`"
)
)
embed.add_field(
name="Examples",
value=(
f"`{ctx.best_prefix}setprofile Mathematics, Bioloyg, Medicine, Undergraduate, Europe`\n"
f"`{ctx.best_prefix}setprofile 2 Biology`"
),
inline=False
)
await ctx.reply(embed=embed)
def validate_tag(*content):
for content in content:
if not set(content.replace('❤️', '')).issubset(string.printable):
raise SafeCancellation(
f"Invalid tag `{content}`!\n"
"Tags may only contain alphanumeric and punctuation characters."
)
if len(content) > MAX_LENGTH:
raise SafeCancellation(
f"Provided tag is too long! Please keep your tags shorter than {MAX_LENGTH} characters."
)

View File

@@ -38,27 +38,20 @@ async def cmd_top(ctx):
) )
top100 = (ctx.args == "100" or ctx.alias == "top100") top100 = (ctx.args == "100" or ctx.alias == "top100")
# Flush any pending coin transactions
Lion.sync()
# Fetch the leaderboard # Fetch the leaderboard
exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members) exclude = set(m.id for m in ctx.guild_settings.unranked_roles.members)
exclude.update(ctx.client.objects['blacklisted_users']) exclude.update(ctx.client.user_blacklist())
exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id]) exclude.update(ctx.client.objects['ignored_members'][ctx.guild.id])
args = {
'guildid': ctx.guild.id,
'select_columns': ('userid', 'total_tracked_time::INTEGER'),
'_extra': "AND total_tracked_time > 0 ORDER BY total_tracked_time DESC " + ("LIMIT 100" if top100 else "")
}
if exclude: if exclude:
user_data = tables.lions.select_where( args['userid'] = data.NOT(list(exclude))
guildid=ctx.guild.id,
userid=data.NOT(list(exclude)), user_data = tables.members_totals.select_where(**args)
select_columns=('userid', 'tracked_time'),
_extra="AND tracked_time > 0 ORDER BY tracked_time DESC " + ("LIMIT 100" if top100 else "")
)
else:
user_data = tables.lions.select_where(
guildid=ctx.guild.id,
select_columns=('userid', 'tracked_time'),
_extra="AND tracked_time > 0 ORDER BY tracked_time DESC " + ("LIMIT 100" if top100 else "")
)
# Quit early if the leaderboard is empty # Quit early if the leaderboard is empty
if not user_data: if not user_data:

View File

@@ -3,6 +3,3 @@ from .module import module
from . import badges from . import badges
from . import timers from . import timers
from . import tracking from . import tracking
from . import top_cmd
from . import stats_cmd

View File

@@ -6,9 +6,8 @@ import contextlib
import discord import discord
from meta import client from meta import client, sharding
from data.conditions import GEQ from data.conditions import GEQ, THIS_SHARD
from core import Lion
from core.data import lions from core.data import lions
from utils.lib import strfdur from utils.lib import strfdur
from settings import GuildSettings from settings import GuildSettings
@@ -55,11 +54,16 @@ async def update_study_badges(full=False):
# Retrieve member rows with out of date study badges # Retrieve member rows with out of date study badges
if not full and client.appdata.last_study_badge_scan is not None: if not full and client.appdata.last_study_badge_scan is not None:
# TODO: _extra here is a hack to cover for inflexible conditionals
update_rows = new_study_badges.select_where( update_rows = new_study_badges.select_where(
_timestamp=GEQ(client.appdata.last_study_badge_scan or 0) guildid=THIS_SHARD,
_timestamp=GEQ(client.appdata.last_study_badge_scan or 0),
_extra="OR session_start IS NOT NULL AND (guildid >> 22) %% {} = {}".format(
sharding.shard_count, sharding.shard_number
)
) )
else: else:
update_rows = new_study_badges.select_where() update_rows = new_study_badges.select_where(guildid=THIS_SHARD)
if not update_rows: if not update_rows:
client.appdata.last_study_badge_scan = datetime.datetime.utcnow() client.appdata.last_study_badge_scan = datetime.datetime.utcnow()
@@ -287,7 +291,6 @@ async def study_badge_tracker():
""" """
while True: while True:
try: try:
Lion.sync()
await update_study_badges() await update_study_badges()
except Exception: except Exception:
# Unknown exception. Catch it so the loop doesn't die. # Unknown exception. Catch it so the loop doesn't die.
@@ -304,11 +307,10 @@ async def study_badge_tracker():
await asyncio.sleep(60) await asyncio.sleep(60)
async def _update_member_studybadge(member): async def update_member_studybadge(member):
""" """
Checks and (if required) updates the study badge for a single member. Checks and (if required) updates the study badge for a single member.
""" """
Lion.fetch(member.guild.id, member.id).flush()
update_rows = new_study_badges.select_where( update_rows = new_study_badges.select_where(
guildid=member.guild.id, guildid=member.guild.id,
userid=member.id userid=member.id
@@ -332,16 +334,6 @@ async def _update_member_studybadge(member):
await _update_guild_badges(member.guild, update_rows) await _update_guild_badges(member.guild, update_rows)
@client.add_after_event("voice_state_update")
async def voice_studybadge_updater(client, member, before, after):
if not client.is_ready():
# The poll loop will pick it up
return
if before.channel and not after.channel:
await _update_member_studybadge(member)
@module.launch_task @module.launch_task
async def launch_study_badge_tracker(client): async def launch_study_badge_tracker(client):
asyncio.create_task(study_badge_tracker()) asyncio.create_task(study_badge_tracker())

View File

@@ -1,4 +1,4 @@
from LionModule import LionModule from LionModule import LionModule
module = LionModule("Study_Stats") module = LionModule("Study_Tracking")

View File

@@ -1,83 +0,0 @@
import datetime
import discord
from cmdClient.checks import in_guild
from utils.lib import strfdur
from data import tables
from core import Lion
from .module import module
@module.cmd(
"stats",
group="Statistics",
desc="View a summary of your study statistics!"
)
@in_guild()
async def cmd_stats(ctx):
"""
Usage``:
{prefix}stats
{prefix}stats <user mention>
Description:
View the study statistics for yourself or the mentioned user.
"""
if ctx.args:
if not ctx.msg.mentions:
return await ctx.error_reply("Please mention a user to view their statistics!")
target = ctx.msg.mentions[0]
else:
target = ctx.author
# Collect the required target data
lion = Lion.fetch(ctx.guild.id, target.id)
rank_data = tables.lion_ranks.select_one_where(
userid=target.id,
guildid=ctx.guild.id
)
# Extract and format data
time = strfdur(lion.time)
coins = lion.coins
workouts = lion.data.workout_count
if lion.data.last_study_badgeid:
badge_row = tables.study_badges.fetch(lion.data.last_study_badgeid)
league = "<@&{}>".format(badge_row.roleid)
else:
league = "No league yet!"
time_lb_pos = rank_data['time_rank']
coin_lb_pos = rank_data['coin_rank']
# Build embed
embed = discord.Embed(
colour=discord.Colour.blue(),
timestamp=datetime.datetime.utcnow(),
title="Revision Statistics"
).set_footer(text=str(target), icon_url=target.avatar_url).set_thumbnail(url=target.avatar_url)
embed.add_field(
name="📚 Study Time",
value=time
)
embed.add_field(
name="🦁 Revision League",
value=league
)
embed.add_field(
name="🦁 LionCoins",
value=coins
)
embed.add_field(
name="🏆 Leaderboard Position",
value="Time: {}\n LC: {}".format(time_lb_pos, coin_lb_pos)
)
embed.add_field(
name="💪 Workouts",
value=workouts
)
embed.add_field(
name="📋 Attendence",
value="TBD"
)
await ctx.reply(embed=embed)

View File

@@ -1,7 +1,21 @@
from data import Table, RowTable, tables from data import Table, RowTable, tables
from utils.lib import FieldEnum
untracked_channels = Table('untracked_channels') untracked_channels = Table('untracked_channels')
class SessionChannelType(FieldEnum):
"""
The possible session channel types.
"""
# NOTE: "None" stands for Unknown, and the STANDARD description should be replaced with the channel name
STANDARD = 'STANDARD', "Standard"
ACCOUNTABILITY = 'ACCOUNTABILITY', "Accountability Room"
RENTED = 'RENTED', "Private Room"
EXTERNAL = 'EXTERNAL', "Unknown"
session_history = Table('session_history') session_history = Table('session_history')
current_sessions = RowTable( current_sessions = RowTable(
'current_sessions', 'current_sessions',
@@ -30,3 +44,19 @@ def close_study_session(guildid, userid):
current_sessions.row_cache.pop((guildid, userid), None) current_sessions.row_cache.pop((guildid, userid), None)
# Use the function output to update the member cache # Use the function output to update the member cache
tables.lions._make_rows(*rows) tables.lions._make_rows(*rows)
@session_history.save_query
def study_time_since(guildid, userid, timestamp):
"""
Retrieve the total member study time (in seconds) since the given timestamp.
Includes the current session, if it exists.
"""
with session_history.conn as conn:
cursor = conn.cursor()
cursor.callproc('study_time_since', (guildid, userid, timestamp))
rows = cursor.fetchall()
return (rows[0][0] if rows else None) or 0
members_totals = Table('members_totals')

View File

@@ -1,21 +1,44 @@
import asyncio import asyncio
import discord import discord
import logging
import traceback
from typing import Dict
from collections import defaultdict from collections import defaultdict
from utils.lib import utc_now from utils.lib import utc_now
from data import tables
from data.conditions import THIS_SHARD
from core import Lion
from meta import client
from ..module import module from ..module import module
from .data import current_sessions from .data import current_sessions, SessionChannelType
from .settings import untracked_channels, hourly_reward, hourly_live_bonus, max_daily_study from .settings import untracked_channels, hourly_reward, hourly_live_bonus
class Session: class Session:
# TODO: Slots """
sessions = defaultdict(dict) A `Session` describes an ongoing study session by a single guild member.
A member is counted as studying when they are in a tracked voice channel.
This class acts as an opaque interface to the corresponding `sessions` data row.
"""
__slots__ = (
'guildid',
'userid',
'_expiry_task'
)
# Global cache of ongoing sessions
sessions: Dict[int, Dict[int, 'Session']] = defaultdict(dict)
# Global cache of members pending session start (waiting for daily cap reset)
members_pending: Dict[int, Dict[int, asyncio.Task]] = defaultdict(dict)
def __init__(self, guildid, userid): def __init__(self, guildid, userid):
self.guildid = guildid self.guildid = guildid
self.userid = userid self.userid = userid
self.key = (guildid, userid)
self._expiry_task: asyncio.Task = None
@classmethod @classmethod
def get(cls, guildid, userid): def get(cls, guildid, userid):
@@ -36,14 +59,36 @@ class Session:
if userid in cls.sessions[guildid]: if userid in cls.sessions[guildid]:
raise ValueError("A session for this member already exists!") raise ValueError("A session for this member already exists!")
# TODO: Handle daily study cap
# TODO: Calculate channel type # If the user is study capped, schedule the session start for the next day
# TODO: Ensure lion if (lion := Lion.fetch(guildid, userid)).remaining_study_today <= 0:
if pending := cls.members_pending[guildid].pop(userid, None):
pending.cancel()
task = asyncio.create_task(cls._delayed_start(guildid, userid, member, state))
cls.members_pending[guildid][userid] = task
client.log(
"Member (uid:{}) in (gid:{}) is study capped, "
"delaying session start for {} seconds until start of next day.".format(
userid, guildid, lion.remaining_in_day
),
context="SESSION_TRACKER",
level=logging.DEBUG
)
return
# TODO: More reliable channel type determination
if state.channel.id in tables.rented.row_cache:
channel_type = SessionChannelType.RENTED
elif state.channel.category and state.channel.category.id == lion.guild_settings.accountability_category.data:
channel_type = SessionChannelType.ACCOUNTABILITY
else:
channel_type = SessionChannelType.STANDARD
current_sessions.create_row( current_sessions.create_row(
guildid=guildid, guildid=guildid,
userid=userid, userid=userid,
channelid=state.channel.id, channelid=state.channel.id,
channel_type=None, channel_type=channel_type,
start_time=now, start_time=now,
live_start=now if (state.self_video or state.self_stream) else None, live_start=now if (state.self_video or state.self_stream) else None,
stream_start=now if state.self_stream else None, stream_start=now if state.self_stream else None,
@@ -51,23 +96,127 @@ class Session:
hourly_coins=hourly_reward.get(guildid).value, hourly_coins=hourly_reward.get(guildid).value,
hourly_live_coins=hourly_live_bonus.get(guildid).value hourly_live_coins=hourly_live_bonus.get(guildid).value
) )
session = cls(guildid, userid) session = cls(guildid, userid).activate()
cls.sessions[guildid][userid] = session client.log(
return session "Started session: {}".format(session.data),
context="SESSION_TRACKER",
level=logging.DEBUG,
)
@classmethod
async def _delayed_start(cls, guildid, userid, *args):
delay = Lion.fetch(guildid, userid).remaining_in_day
try:
await asyncio.sleep(delay)
except asyncio.CancelledError:
pass
else:
cls.start(*args)
@property
def key(self):
"""
RowTable Session identification key.
"""
return (self.guildid, self.userid)
@property
def lion(self):
"""
The Lion member object associated with this member.
"""
return Lion.fetch(self.guildid, self.userid)
@property @property
def data(self): def data(self):
"""
Row of the `current_sessions` table corresponding to this session.
"""
return current_sessions.fetch(self.key) return current_sessions.fetch(self.key)
@property
def duration(self):
"""
Current duration of the session.
"""
return (utc_now() - self.data.start_time).total_seconds()
@property
def coins_earned(self):
"""
Number of coins earned so far.
"""
data = self.data
coins = self.duration * data.hourly_coins
coins += data.live_duration * data.hourly_live_coins
if data.live_start:
coins += (utc_now() - data.live_start).total_seconds() * data.hourly_live_coins
return coins // 3600
def activate(self):
"""
Activate the study session.
This adds the session to the studying members cache,
and schedules the session expiry, based on the daily study cap.
"""
# Add to the active cache
self.sessions[self.guildid][self.userid] = self
# Schedule the session expiry
self.schedule_expiry()
# Return self for easy chaining
return self
def schedule_expiry(self):
"""
Schedule session termination when the user reaches the maximum daily study time.
"""
asyncio.create_task(self._schedule_expiry())
async def _schedule_expiry(self):
# Cancel any existing expiry
if self._expiry_task and not self._expiry_task.done():
self._expiry_task.cancel()
# Wait for the maximum session length
try:
self._expiry_task = await asyncio.sleep(self.lion.remaining_study_today)
except asyncio.CancelledError:
pass
else:
if self.lion.remaining_study_today <= 0:
# End the session
# Note that the user will not automatically start a new session when the day starts
# TODO: Notify user? Disconnect them?
client.log(
"Session for (uid:{}) in (gid:{}) reached daily guild study cap.\n{}".format(
self.userid, self.guildid, self.data
),
context="SESSION_TRACKER"
)
self.finish()
else:
# It's possible the expiry time was pushed forwards while waiting
# If so, reschedule
self.schedule_expiry()
def finish(self): def finish(self):
""" """
Close the study session. Close the study session.
""" """
self.sessions[self.guildid].pop(self.userid, None)
# Note that save_live_status doesn't need to be called here # Note that save_live_status doesn't need to be called here
# The database saving procedure will account for the values. # The database saving procedure will account for the values.
current_sessions.queries.close_study_session(*self.key) current_sessions.queries.close_study_session(*self.key)
# Remove session from active cache
self.sessions[self.guildid].pop(self.userid, None)
# Cancel any existing expiry task
if self._expiry_task and not self._expiry_task.done():
self._expiry_task.cancel()
def save_live_status(self, state: discord.VoiceState): def save_live_status(self, state: discord.VoiceState):
""" """
Update the saved live status of the member. Update the saved live status of the member.
@@ -101,6 +250,7 @@ async def session_voice_tracker(client, member, before, after):
Voice update event dispatcher for study session tracking. Voice update event dispatcher for study session tracking.
""" """
guild = member.guild guild = member.guild
Lion.fetch(guild.id, member.id)
session = Session.get(guild.id, member.id) session = Session.get(guild.id, member.id)
if before.channel == after.channel: if before.channel == after.channel:
@@ -111,12 +261,45 @@ async def session_voice_tracker(client, member, before, after):
else: else:
# Member changed channel # Member changed channel
# End the current session and start a new one, if applicable # End the current session and start a new one, if applicable
# TODO: Max daily study session tasks
if session: if session:
if (scid := session.data.channelid) and (not before.channel or scid != before.channel.id):
client.log(
"The previous voice state for "
"member {member.name} (uid:{member.id}) in {guild.name} (gid:{guild.id}) "
"does not match their current study session!\n"
"Session channel is (cid:{scid}), but the previous channel is {previous}.".format(
member=member,
guild=member.guild,
scid=scid,
previous="{0.name} (cid:{0.id})".format(before.channel) if before.channel else "None"
),
context="SESSION_TRACKER",
level=logging.ERROR
)
client.log(
"Ending study session for {member.name} (uid:{member.id}) "
"in {member.guild.id} (gid:{member.guild.id}) since they left the voice channel.\n{session}".format(
member=member,
session=session.data
),
context="SESSION_TRACKER",
post=False
)
# End the current session # End the current session
session.finish() session.finish()
elif pending := Session.members_pending[guild.id].pop(member.id, None):
client.log(
"Cancelling pending study session for {member.name} (uid:{member.id}) "
"in {member.guild.name} (gid:{member.guild.id}) since they left the voice channel.".format(
member=member
),
context="SESSION_TRACKER",
post=False
)
pending.cancel()
if after.channel: if after.channel:
blacklist = client.objects['blacklisted_users'] blacklist = client.user_blacklist()
guild_blacklist = client.objects['ignored_members'][guild.id] guild_blacklist = client.objects['ignored_members'][guild.id]
untracked = untracked_channels.get(guild.id).data untracked = untracked_channels.get(guild.id).data
start_session = ( start_session = (
@@ -126,7 +309,72 @@ async def session_voice_tracker(client, member, before, after):
) )
if start_session: if start_session:
# Start a new session for the member # Start a new session for the member
Session.start(member, after) client.log(
"Starting a new voice channel study session for {member.name} (uid:{member.id}) "
"in {member.guild.name} (gid:{member.guild.id}).".format(
member=member,
),
context="SESSION_TRACKER",
post=False
)
session = Session.start(member, after)
async def leave_guild_sessions(client, guild):
"""
`guild_leave` hook.
Close all sessions in the guild when we leave.
"""
sessions = list(Session.sessions[guild.id].values())
for session in sessions:
session.finish()
client.log(
"Left {} (gid:{}) and closed {} ongoing study sessions.".format(guild.name, guild.id, len(sessions)),
context="SESSION_TRACKER"
)
async def join_guild_sessions(client, guild):
"""
`guild_join` hook.
Refresh all sessions for the guild when we rejoin.
"""
# Delete existing current sessions, which should have been closed when we left
# It is possible we were removed from the guild during an outage
current_sessions.delete_where(guildid=guild.id)
untracked = untracked_channels.get(guild.id).data
members = [
member
for channel in guild.voice_channels
for member in channel.members
if channel.members and channel.id not in untracked
]
for member in members:
client.log(
"Starting new session for '{}' (uid: {}) in '{}' (cid: {}) of '{}' (gid: {})".format(
member.name,
member.id,
member.voice.channel.name,
member.voice.channel.id,
member.guild.name,
member.guild.id
),
context="SESSION_TRACKER",
level=logging.INFO,
post=False
)
Session.start(member, member.voice)
# Log newly started sessions
client.log(
"Joined {} (gid:{}) and started {} new study sessions from current voice channel members.".format(
guild.name,
guild.id,
len(members)
),
context="SESSION_TRACKER",
)
async def _init_session_tracker(client): async def _init_session_tracker(client):
@@ -135,9 +383,111 @@ async def _init_session_tracker(client):
update them depending on the current voice states, update them depending on the current voice states,
and attach the voice event handler. and attach the voice event handler.
""" """
# Ensure the client caches are ready and guilds are chunked
await client.wait_until_ready() await client.wait_until_ready()
# Pre-cache the untracked channels
await untracked_channels.launch_task(client) await untracked_channels.launch_task(client)
# Log init start and define logging counters
client.log(
"Loading ongoing study sessions.",
context="SESSION_INIT",
level=logging.DEBUG
)
resumed = 0
ended = 0
# Grab all ongoing sessions from data
rows = current_sessions.fetch_rows_where(guildid=THIS_SHARD)
# Iterate through, resume or end as needed
for row in rows:
if (guild := client.get_guild(row.guildid)) is not None and row.channelid is not None:
try:
# Load the Session
session = Session(row.guildid, row.userid)
# Find the channel and member voice state
voice = None
if channel := guild.get_channel(row.channelid):
voice = next((member.voice for member in channel.members if member.id == row.userid), None)
# Resume or end as required
if voice and voice.channel:
client.log(
"Resuming ongoing session: {}".format(row),
context="SESSION_INIT",
level=logging.DEBUG
)
session.activate()
session.save_live_status(voice)
resumed += 1
else:
client.log(
"Ending already completed session: {}".format(row),
context="SESSION_INIT",
level=logging.DEBUG
)
session.finish()
ended += 1
except Exception:
# Fatal error
client.log(
"Fatal error occurred initialising session: {}\n{}".format(row, traceback.format_exc()),
context="SESSION_INIT",
level=logging.CRITICAL
)
module.ready = False
return
# Log resumed sessions
client.log(
"Resumed {} ongoing study sessions, and ended {}.".format(resumed, ended),
context="SESSION_INIT",
level=logging.INFO
)
# Now iterate through members of all tracked voice channels
# Start sessions if they don't already exist
tracked_channels = [
channel
for guild in client.guilds
for channel in guild.voice_channels
if channel.members and channel.id not in untracked_channels.get(guild.id).data
]
new_members = [
member
for channel in tracked_channels
for member in channel.members
if not Session.get(member.guild.id, member.id)
]
for member in new_members:
client.log(
"Starting new session for '{}' (uid: {}) in '{}' (cid: {}) of '{}' (gid: {})".format(
member.name,
member.id,
member.voice.channel.name,
member.voice.channel.id,
member.guild.name,
member.guild.id
),
context="SESSION_INIT",
level=logging.DEBUG
)
Session.start(member, member.voice)
# Log newly started sessions
client.log(
"Started {} new study sessions from current voice channel members.".format(len(new_members)),
context="SESSION_INIT",
level=logging.INFO
)
# Now that we are in a valid initial state, attach the session event handler
client.add_after_event("voice_state_update", session_voice_tracker) client.add_after_event("voice_state_update", session_voice_tracker)
client.add_after_event("guild_remove", leave_guild_sessions)
client.add_after_event("guild_join", join_guild_sessions)
@module.launch_task @module.launch_task

View File

@@ -1,5 +1,3 @@
from collections import defaultdict
import settings import settings
from settings import GuildSettings from settings import GuildSettings
from wards import guild_admin from wards import guild_admin
@@ -52,10 +50,10 @@ class untracked_channels(settings.ChannelList, settings.ListData, settings.Setti
if any(channel.members for channel in guild.voice_channels) if any(channel.members for channel in guild.voice_channels)
] ]
if active_guildids: if active_guildids:
cache = {guildid: [] for guildid in active_guildids}
rows = cls._table_interface.select_where( rows = cls._table_interface.select_where(
guildid=active_guildids guildid=active_guildids
) )
cache = defaultdict(list)
for row in rows: for row in rows:
cache[row['guildid']].append(row['channelid']) cache[row['guildid']].append(row['channelid'])
cls._cache.update(cache) cls._cache.update(cache)
@@ -114,11 +112,30 @@ class hourly_live_bonus(settings.Integer, settings.GuildSetting):
@GuildSettings.attach_setting @GuildSettings.attach_setting
class max_daily_study(settings.Duration, settings.GuildSetting): class daily_study_cap(settings.Duration, settings.GuildSetting):
category = "Study Tracking" category = "Study Tracking"
attr_name = "max_daily_study" attr_name = "daily_study_cap"
_data_column = "max_daily_study" _data_column = "daily_study_cap"
display_name = "max_daily_study" display_name = "daily_study_cap"
desc = "Maximum amount of study time ..." desc = "Maximum amount of recorded study time per member per day."
_default = 16 * 60 * 60
_default_multiplier = 60 * 60
_max = 25 * 60 * 60
long_desc = (
"The maximum amount of study time that can be recorded for a member per day, "
"intended to remove system encouragement for unhealthy or obsessive behaviour.\n"
"The member may study for longer, but their sessions will not be tracked. "
"The start and end of the day are determined by the member's configured timezone."
)
@property
def success_response(self):
# Refresh expiry for all sessions in the guild
[session.schedule_expiry() for session in self.client.objects['sessions'][self.id].values()]
return "The maximum tracked daily study time is now {}.".format(self.formatted)

View File

@@ -47,7 +47,7 @@ def _scan(guild):
members = itertools.chain(*channel_members) members = itertools.chain(*channel_members)
# TODO filter out blacklisted users # TODO filter out blacklisted users
blacklist = client.objects['blacklisted_users'] blacklist = client.user_blacklist()
guild_blacklist = client.objects['ignored_members'][guild.id] guild_blacklist = client.objects['ignored_members'][guild.id]
for member in members: for member in members:

View File

@@ -7,6 +7,8 @@ import discord
from cmdClient.checks import is_owner from cmdClient.checks import is_owner
from cmdClient.lib import ResponseTimedOut from cmdClient.lib import ResponseTimedOut
from meta.sharding import sharded
from .module import module from .module import module
@@ -26,14 +28,14 @@ async def cmd_guildblacklist(ctx, flags):
Description: Description:
View, add, or remove guilds from the blacklist. View, add, or remove guilds from the blacklist.
""" """
blacklist = ctx.client.objects['blacklisted_guilds'] blacklist = ctx.client.guild_blacklist()
if ctx.args: if ctx.args:
# guildid parsing # guildid parsing
items = [item.strip() for item in ctx.args.split(',')] items = [item.strip() for item in ctx.args.split(',')]
if any(not item.isdigit() for item in items): if any(not item.isdigit() for item in items):
return await ctx.error_reply( return await ctx.error_reply(
"Please provide guilds as comma seprated guild ids." "Please provide guilds as comma separated guild ids."
) )
guildids = set(int(item) for item in items) guildids = set(int(item) for item in items)
@@ -80,9 +82,18 @@ async def cmd_guildblacklist(ctx, flags):
insert_keys=('guildid', 'ownerid', 'reason') insert_keys=('guildid', 'ownerid', 'reason')
) )
# Check if we are in any of these guilds # Leave freshly blacklisted guilds, accounting for shards
to_leave = (ctx.client.get_guild(guildid) for guildid in to_add) to_leave = []
to_leave = [guild for guild in to_leave if guild is not None] for guildid in to_add:
guild = ctx.client.get_guild(guildid)
if not guild and sharded:
try:
guild = await ctx.client.fetch_guild(guildid)
except discord.HTTPException:
pass
if guild:
to_leave.append(guild)
for guild in to_leave: for guild in to_leave:
await guild.leave() await guild.leave()
@@ -102,9 +113,8 @@ async def cmd_guildblacklist(ctx, flags):
) )
# Refresh the cached blacklist after modification # Refresh the cached blacklist after modification
ctx.client.objects['blacklisted_guilds'] = set( ctx.client.guild_blacklist.cache_clear()
row['guildid'] for row in ctx.client.data.global_guild_blacklist.select_where() ctx.client.guild_blacklist()
)
else: else:
# Display the current blacklist # Display the current blacklist
# First fetch the full blacklist data # First fetch the full blacklist data
@@ -183,7 +193,7 @@ async def cmd_userblacklist(ctx, flags):
Description: Description:
View, add, or remove users from the blacklist. View, add, or remove users from the blacklist.
""" """
blacklist = ctx.client.objects['blacklisted_users'] blacklist = ctx.client.user_blacklist()
if ctx.args: if ctx.args:
# userid parsing # userid parsing
@@ -245,9 +255,8 @@ async def cmd_userblacklist(ctx, flags):
) )
# Refresh the cached blacklist after modification # Refresh the cached blacklist after modification
ctx.client.objects['blacklisted_users'] = set( ctx.client.user_blacklist.cache_clear()
row['userid'] for row in ctx.client.data.global_user_blacklist.select_where() ctx.client.user_blacklist()
)
else: else:
# Display the current blacklist # Display the current blacklist
# First fetch the full blacklist data # First fetch the full blacklist data

View File

@@ -13,19 +13,13 @@ async def update_status():
# TODO: Make globally configurable and saveable # TODO: Make globally configurable and saveable
global _last_update global _last_update
if time.time() - _last_update < 30: if time.time() - _last_update < 60:
return return
_last_update = time.time() _last_update = time.time()
student_count = sum( student_count, room_count = client.data.current_sessions.select_one_where(
len(ch.members) select_columns=("COUNT(*) AS studying_count", "COUNT(DISTINCT(channelid)) AS channel_count"),
for guild in client.guilds
for ch in guild.voice_channels
)
room_count = sum(
len([vc for vc in guild.voice_channels if vc.members])
for guild in client.guilds
) )
status = "{} students in {} study rooms!".format(student_count, room_count) status = "{} students in {} study rooms!".format(student_count, room_count)

View File

@@ -6,8 +6,9 @@ import asyncio
from cmdClient.lib import SafeCancellation from cmdClient.lib import SafeCancellation
from meta import client from meta import client
from core import Lion from core import Lion
from data import NULL, NOTNULL
from settings import GuildSettings from settings import GuildSettings
from utils.lib import parse_ranges from utils.lib import parse_ranges, utc_now
from . import data from . import data
# from .module import module # from .module import module
@@ -130,12 +131,12 @@ class Tasklist:
""" """
self.tasklist = data.tasklist.fetch_rows_where( self.tasklist = data.tasklist.fetch_rows_where(
userid=self.member.id, userid=self.member.id,
_extra=("AND last_updated_at > timezone('utc', NOW()) - INTERVAL '24h' " deleted_at=NULL,
"ORDER BY created_at ASC, taskid ASC") _extra="ORDER BY created_at ASC, taskid ASC"
) )
self._refreshed_at = datetime.datetime.utcnow() self._refreshed_at = datetime.datetime.utcnow()
def _format_tasklist(self): async def _format_tasklist(self):
""" """
Generates a sequence of pages from the tasklist Generates a sequence of pages from the tasklist
""" """
@@ -144,7 +145,7 @@ class Tasklist:
"{num:>{numlen}}. [{mark}] {content}".format( "{num:>{numlen}}. [{mark}] {content}".format(
num=i, num=i,
numlen=((self.block_size * (i // self.block_size + 1) - 1) // 10) + 1, numlen=((self.block_size * (i // self.block_size + 1) - 1) // 10) + 1,
mark=self.checkmark if task.complete else ' ', mark=self.checkmark if task.completed_at else ' ',
content=task.content content=task.content
) )
for i, task in enumerate(self.tasklist) for i, task in enumerate(self.tasklist)
@@ -159,7 +160,7 @@ class Tasklist:
# Formatting strings and data # Formatting strings and data
page_count = len(task_blocks) or 1 page_count = len(task_blocks) or 1
task_count = len(task_strings) task_count = len(task_strings)
complete_count = len([task for task in self.tasklist if task.complete]) complete_count = len([task for task in self.tasklist if task.completed_at])
if task_count > 0: if task_count > 0:
title = "TODO list ({}/{} complete)".format( title = "TODO list ({}/{} complete)".format(
@@ -176,7 +177,7 @@ class Tasklist:
hint = "Type `add <task>` to start adding tasks! E.g. `add Revise Maths Paper 1`." hint = "Type `add <task>` to start adding tasks! E.g. `add Revise Maths Paper 1`."
task_blocks = [""] # Empty page so we can post task_blocks = [""] # Empty page so we can post
# Create formtted page embeds, adding help if required # Create formatted page embeds, adding help if required
pages = [] pages = []
for i, block in enumerate(task_blocks): for i, block in enumerate(task_blocks):
embed = discord.Embed( embed = discord.Embed(
@@ -205,7 +206,7 @@ class Tasklist:
# Calculate or adjust the current page number # Calculate or adjust the current page number
if self.current_page is None: if self.current_page is None:
# First page with incomplete task, or the first page # First page with incomplete task, or the first page
first_incomplete = next((i for i, task in enumerate(self.tasklist) if not task.complete), 0) first_incomplete = next((i for i, task in enumerate(self.tasklist) if not task.completed_at), 0)
self.current_page = first_incomplete // self.block_size self.current_page = first_incomplete // self.block_size
elif self.current_page >= len(self.pages): elif self.current_page >= len(self.pages):
self.current_page = len(self.pages) - 1 self.current_page = len(self.pages) - 1
@@ -233,6 +234,12 @@ class Tasklist:
self.message = message self.message = message
self.messages[message.id] = self self.messages[message.id] = self
async def _update(self):
"""
Update the current message with the current page.
"""
await self.message.edit(embed=self.pages[self.current_page])
async def update(self, repost=None): async def update(self, repost=None):
""" """
Update the displayed tasklist. Update the displayed tasklist.
@@ -243,7 +250,7 @@ class Tasklist:
# Update data and make page list # Update data and make page list
self._refresh() self._refresh()
self._format_tasklist() await self._format_tasklist()
self._adjust_current_page() self._adjust_current_page()
if self.message and not repost: if self.message and not repost:
@@ -266,7 +273,8 @@ class Tasklist:
if not repost: if not repost:
try: try:
await self.message.edit(embed=self.pages[self.current_page]) # TODO: Refactor into update method
await self._update()
# Add or remove paging reactions as required # Add or remove paging reactions as required
should_have_paging = len(self.pages) > 1 should_have_paging = len(self.pages) > 1
@@ -387,8 +395,14 @@ class Tasklist:
Delete tasks from the task list Delete tasks from the task list
""" """
taskids = [self.tasklist[i].taskid for i in indexes] taskids = [self.tasklist[i].taskid for i in indexes]
return data.tasklist.delete_where(
taskid=taskids now = utc_now()
return data.tasklist.update_where(
{
'deleted_at': now,
'last_updated_at': now
},
taskid=taskids,
) )
def _edit_task(self, index, new_content): def _edit_task(self, index, new_content):
@@ -396,10 +410,12 @@ class Tasklist:
Update the provided task with the new content Update the provided task with the new content
""" """
taskid = self.tasklist[index].taskid taskid = self.tasklist[index].taskid
now = utc_now()
return data.tasklist.update_where( return data.tasklist.update_where(
{ {
'content': new_content, 'content': new_content,
'last_updated_at': datetime.datetime.utcnow() 'last_updated_at': now
}, },
taskid=taskid, taskid=taskid,
) )
@@ -409,13 +425,15 @@ class Tasklist:
Mark provided tasks as complete Mark provided tasks as complete
""" """
taskids = [self.tasklist[i].taskid for i in indexes] taskids = [self.tasklist[i].taskid for i in indexes]
now = utc_now()
return data.tasklist.update_where( return data.tasklist.update_where(
{ {
'complete': True, 'completed_at': now,
'last_updated_at': datetime.datetime.utcnow() 'last_updated_at': now
}, },
taskid=taskids, taskid=taskids,
complete=False, completed_at=NULL,
) )
def _uncheck_tasks(self, *indexes): def _uncheck_tasks(self, *indexes):
@@ -423,13 +441,15 @@ class Tasklist:
Mark provided tasks as incomplete Mark provided tasks as incomplete
""" """
taskids = [self.tasklist[i].taskid for i in indexes] taskids = [self.tasklist[i].taskid for i in indexes]
now = utc_now()
return data.tasklist.update_where( return data.tasklist.update_where(
{ {
'complete': False, 'completed_at': None,
'last_updated_at': datetime.datetime.utcnow() 'last_updated_at': now
}, },
taskid=taskids, taskid=taskids,
complete=True, completed_at=NOTNULL,
) )
def _index_range_parser(self, userstr): def _index_range_parser(self, userstr):
@@ -459,7 +479,7 @@ class Tasklist:
count = data.tasklist.select_one_where( count = data.tasklist.select_one_where(
select_columns=("COUNT(*)",), select_columns=("COUNT(*)",),
userid=self.member.id, userid=self.member.id,
_extra="AND last_updated_at > timezone('utc', NOW()) - INTERVAL '24h'" deleted_at=NULL
)[0] )[0]
# Fetch maximum allowed count # Fetch maximum allowed count
@@ -496,8 +516,8 @@ class Tasklist:
# Parse provided ranges # Parse provided ranges
indexes = self._index_range_parser(userstr) indexes = self._index_range_parser(userstr)
to_check = [index for index in indexes if not self.tasklist[index].complete] to_check = [index for index in indexes if not self.tasklist[index].completed_at]
to_uncheck = [index for index in indexes if self.tasklist[index].complete] to_uncheck = [index for index in indexes if self.tasklist[index].completed_at]
if to_uncheck: if to_uncheck:
self._uncheck_tasks(*to_uncheck) self._uncheck_tasks(*to_uncheck)
@@ -572,21 +592,21 @@ class Tasklist:
self.current_page %= len(self.pages) self.current_page %= len(self.pages)
if self.show_help: if self.show_help:
self.show_help = False self.show_help = False
self._format_tasklist() await self._format_tasklist()
await self.message.edit(embed=self.pages[self.current_page]) await self._update()
elif str_emoji == self.prev_emoji and user.id == self.member.id: elif str_emoji == self.prev_emoji and user.id == self.member.id:
self.current_page -= 1 self.current_page -= 1
self.current_page %= len(self.pages) self.current_page %= len(self.pages)
if self.show_help: if self.show_help:
self.show_help = False self.show_help = False
self._format_tasklist() await self._format_tasklist()
await self.message.edit(embed=self.pages[self.current_page]) await self._update()
elif str_emoji == self.cancel_emoji and user.id == self.member.id: elif str_emoji == self.cancel_emoji and user.id == self.member.id:
await self.deactivate(delete=True) await self.deactivate(delete=True)
elif str_emoji == self.question_emoji and user.id == self.member.id: elif str_emoji == self.question_emoji and user.id == self.member.id:
self.show_help = not self.show_help self.show_help = not self.show_help
self._format_tasklist() await self._format_tasklist()
await self.message.edit(embed=self.pages[self.current_page]) await self._update()
elif str_emoji == self.refresh_emoji and user.id == self.member.id: elif str_emoji == self.refresh_emoji and user.id == self.member.id:
await self.update() await self.update()
@@ -687,15 +707,3 @@ async def tasklist_message_handler(client, message):
async def tasklist_reaction_add_handler(client, reaction, user): async def tasklist_reaction_add_handler(client, reaction, user):
if user != client.user and reaction.message.id in Tasklist.messages: if user != client.user and reaction.message.id in Tasklist.messages:
await Tasklist.messages[reaction.message.id].handle_reaction(reaction, user, True) await Tasklist.messages[reaction.message.id].handle_reaction(reaction, user, True)
# @module.launch_task
# Commented because we don't actually need to expire these
async def tasklist_expiry_watchdog(client):
removed = data.tasklist.queries.expire_old_tasks()
if removed:
client.log(
"Remove {} stale todo tasks.".format(len(removed)),
context="TASKLIST_EXPIRY",
post=True
)

View File

@@ -2,23 +2,11 @@ from data import RowTable, Table
tasklist = RowTable( tasklist = RowTable(
'tasklist', 'tasklist',
('taskid', 'userid', 'content', 'complete', 'rewarded', 'created_at', 'last_updated_at'), ('taskid', 'userid', 'content', 'rewarded', 'created_at', 'completed_at', 'deleted_at', 'last_updated_at'),
'taskid' 'taskid'
) )
@tasklist.save_query
def expire_old_tasks():
with tasklist.conn:
with tasklist.conn.cursor() as curs:
curs.execute(
"DELETE FROM tasklist WHERE "
"last_updated_at < timezone('utc', NOW()) - INTERVAL '7d' "
"RETURNING *"
)
return curs.fetchall()
tasklist_channels = Table('tasklist_channels') tasklist_channels = Table('tasklist_channels')
tasklist_rewards = Table('tasklist_reward_history') tasklist_rewards = Table('tasklist_reward_history')

View File

@@ -7,6 +7,7 @@ from core import Lion
from settings import GuildSettings from settings import GuildSettings
from meta import client from meta import client
from data import NULL, tables from data import NULL, tables
from data.conditions import THIS_SHARD
from .module import module from .module import module
from .data import workout_sessions from .data import workout_sessions
@@ -170,7 +171,7 @@ async def workout_voice_tracker(client, member, before, after):
if member.bot: if member.bot:
return return
if member.id in client.objects['blacklisted_users']: if member.id in client.user_blacklist():
return return
if member.id in client.objects['ignored_members'][member.guild.id]: if member.id in client.objects['ignored_members'][member.guild.id]:
return return
@@ -226,7 +227,8 @@ async def load_workouts(client):
client.objects['current_workouts'] = {} # (guildid, userid) -> Row client.objects['current_workouts'] = {} # (guildid, userid) -> Row
# Process any incomplete workouts # Process any incomplete workouts
workouts = workout_sessions.fetch_rows_where( workouts = workout_sessions.fetch_rows_where(
duration=NULL duration=NULL,
guildid=THIS_SHARD
) )
count = 0 count = 0
for workout in workouts: for workout in workouts:

View File

@@ -26,21 +26,22 @@ async def embed_reply(ctx, desc, colour=discord.Colour.orange(), **kwargs):
@Context.util @Context.util
async def error_reply(ctx, error_str, **kwargs): async def error_reply(ctx, error_str, send_args={}, **kwargs):
""" """
Notify the user of a user level error. Notify the user of a user level error.
Typically, this will occur in a red embed, posted in the command channel. Typically, this will occur in a red embed, posted in the command channel.
""" """
embed = discord.Embed( embed = discord.Embed(
colour=discord.Colour.red(), colour=discord.Colour.red(),
description=error_str description=error_str,
**kwargs
) )
message = None message = None
try: try:
message = await ctx.ch.send( message = await ctx.ch.send(
embed=embed, embed=embed,
reference=ctx.msg.to_reference(fail_if_not_exists=False), reference=ctx.msg.to_reference(fail_if_not_exists=False),
**kwargs **send_args
) )
ctx.sent_messages.append(message) ctx.sent_messages.append(message)
return message return message

92
bot/utils/ratelimits.py Normal file
View File

@@ -0,0 +1,92 @@
import time
from cmdClient.lib import SafeCancellation
from cachetools import TTLCache
class BucketFull(Exception):
"""
Throw when a requested Bucket is already full
"""
pass
class BucketOverFull(BucketFull):
"""
Throw when a requested Bucket is overfull
"""
pass
class Bucket:
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full')
def __init__(self, max_level, empty_time):
self.max_level = max_level
self.empty_time = empty_time
self.leak_rate = max_level / empty_time
self._level = 0
self._last_checked = time.time()
self._last_full = False
@property
def overfull(self):
self._leak()
return self._level > self.max_level
def _leak(self):
if self._level:
elapsed = time.time() - self._last_checked
self._level = max(0, self._level - (elapsed * self.leak_rate))
self._last_checked = time.time()
def request(self):
self._leak()
if self._level + 1 > self.max_level + 1:
raise BucketOverFull
elif self._level + 1 > self.max_level:
self._level += 1
if self._last_full:
raise BucketOverFull
else:
self._last_full = True
raise BucketFull
else:
self._last_full = False
self._level += 1
class RateLimit:
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
self.max_level = max_level
self.empty_time = empty_time
self.error = error or "Too many requests, please slow down!"
self.buckets = cache
def request_for(self, key):
if not (bucket := self.buckets.get(key, None)):
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
try:
bucket.request()
except BucketOverFull:
raise SafeCancellation(details="Bucket overflow")
except BucketFull:
raise SafeCancellation(self.error, details="Bucket full")
def ward(self, member=True, key=None):
"""
Command ratelimit decorator.
"""
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
def decorator(func):
async def wrapper(ctx, *args, **kwargs):
self.request_for(key(ctx))
return await func(ctx, *args, **kwargs)
return wrapper
return decorator

View File

@@ -182,6 +182,10 @@ async def find_channel(ctx, userstr, interactive=False, collection=None, chan_ty
# Create the collection to search from args or guild channels # Create the collection to search from args or guild channels
collection = collection if collection else ctx.guild.channels collection = collection if collection else ctx.guild.channels
if chan_type is not None: if chan_type is not None:
if chan_type == discord.ChannelType.text:
# Hack to support news channels as text channels
collection = [chan for chan in collection if isinstance(chan, discord.TextChannel)]
else:
collection = [chan for chan in collection if chan.type == chan_type] collection = [chan for chan in collection if chan.type == chan_type]
# If the user input was a number or possible channel mention, extract it # If the user input was a number or possible channel mention, extract it
@@ -413,7 +417,7 @@ async def find_message(ctx, msgid, chlist=None, ignore=[]):
async def _search_in_channel(channel: discord.TextChannel, msgid: int): async def _search_in_channel(channel: discord.TextChannel, msgid: int):
if channel.type != discord.ChannelType.text: if not isinstance(channel, discord.TextChannel):
return return
try: try:
message = await channel.fetch_message(msgid) message = await channel.fetch_message(msgid)

View File

@@ -1,6 +1,7 @@
[DEFAULT] [DEFAULT]
log_file = bot.log log_file = bot.log
log_channel = log_channel =
error_channel =
guild_log_channel = guild_log_channel =
prefix = ! prefix = !
@@ -10,4 +11,6 @@ owners = 413668234269818890, 389399222400712714
database = dbname=lionbot database = dbname=lionbot
data_appid = LionBot data_appid = LionBot
shard_count = 1
lion_sync_period = 60 lion_sync_period = 60

View File

@@ -1,19 +1,22 @@
/* DROP TYPE IF EXISTS SessionChannelType CASCADE; */ -- DROP TYPE IF EXISTS SessionChannelType CASCADE;
/* DROP TABLE IF EXISTS session_history CASCADE; */ -- DROP TABLE IF EXISTS session_history CASCADE;
/* DROP TABLE IF EXISTS current_sessions CASCADE; */ -- DROP TABLE IF EXISTS current_sessions CASCADE;
/* DROP FUNCTION IF EXISTS close_study_session; */ -- DROP FUNCTION IF EXISTS close_study_session(_guildid BIGINT, _userid BIGINT);
-- DROP FUNCTION IF EXISTS study_time_since(_guildid BIGINT, _userid BIGINT, _timestamp TIMESTAMPTZ)
-- DROP VIEW IF EXISTS current_sessions_totals CASCADE;
DROP VIEW IF EXISTS member_totals CASCADE;
DROP VIEW IF EXISTS member_ranks CASCADE;
DROP VIEW IF EXISTS current_study_badges CASCADE;
DROP VIEW IF EXISTS new_study_badges CASCADE;
/* DROP VIEW IF EXISTS current_sessions_totals CASCADE; */
/* DROP VIEW IF EXISTS member_totals CASCADE; */
/* DROP VIEW IF EXISTS member_ranks CASCADE; */
/* DROP VIEW IF EXISTS current_study_badges CASCADE; */
/* DROP VIEW IF EXISTS new_study_badges CASCADE; */
CREATE TYPE SessionChannelType AS ENUM ( CREATE TYPE SessionChannelType AS ENUM (
'STANDARD',
'ACCOUNTABILITY', 'ACCOUNTABILITY',
'RENTED', 'RENTED',
'EXTERNAL', 'EXTERNAL'
'MIGRATED'
); );
CREATE TABLE session_history( CREATE TABLE session_history(
@@ -74,7 +77,7 @@ AS $$
) SELECT ) SELECT
guildid, userid, channelid, channel_type, start_time, guildid, userid, channelid, channel_type, start_time,
total_duration, total_stream_duration, total_video_duration, total_live_duration, total_duration, total_stream_duration, total_video_duration, total_live_duration,
(total_duration * hourly_coins + live_duration * hourly_live_coins) / 60 (total_duration * hourly_coins + live_duration * hourly_live_coins) / 3600
FROM current_sesh FROM current_sesh
RETURNING * RETURNING *
) )
@@ -105,7 +108,7 @@ CREATE VIEW members_totals AS
*, *,
sesh.start_time AS session_start, sesh.start_time AS session_start,
tracked_time + COALESCE(sesh.total_duration, 0) AS total_tracked_time, tracked_time + COALESCE(sesh.total_duration, 0) AS total_tracked_time,
coins + COALESCE((sesh.total_duration * sesh.hourly_coins + sesh.live_duration * sesh.hourly_live_coins) / 60, 0) AS total_coins coins + COALESCE((sesh.total_duration * sesh.hourly_coins + sesh.live_duration * sesh.hourly_live_coins) / 3600, 0) AS total_coins
FROM members FROM members
LEFT JOIN current_sessions_totals sesh USING (guildid, userid); LEFT JOIN current_sessions_totals sesh USING (guildid, userid);
@@ -122,7 +125,7 @@ CREATE VIEW current_study_badges AS
*, *,
(SELECT r.badgeid (SELECT r.badgeid
FROM study_badges r FROM study_badges r
WHERE r.guildid = members_totals.guildid AND members_totals.tracked_time > r.required_time WHERE r.guildid = members_totals.guildid AND members_totals.total_tracked_time > r.required_time
ORDER BY r.required_time DESC ORDER BY r.required_time DESC
LIMIT 1) AS current_study_badgeid LIMIT 1) AS current_study_badgeid
FROM members_totals; FROM members_totals;
@@ -134,3 +137,44 @@ CREATE VIEW new_study_badges AS
WHERE WHERE
last_study_badgeid IS DISTINCT FROM current_study_badgeid last_study_badgeid IS DISTINCT FROM current_study_badgeid
ORDER BY guildid; ORDER BY guildid;
CREATE FUNCTION study_time_since(_guildid BIGINT, _userid BIGINT, _timestamp TIMESTAMPTZ)
RETURNS INTEGER
AS $$
BEGIN
RETURN (
SELECT
SUM(
CASE
WHEN start_time >= _timestamp THEN duration
ELSE EXTRACT(EPOCH FROM (end_time - _timestamp))
END
)
FROM (
SELECT
start_time,
duration,
(start_time + duration * interval '1 second') AS end_time
FROM session_history
WHERE
guildid=_guildid
AND userid=_userid
AND (start_time + duration * interval '1 second') >= _timestamp
UNION
SELECT
start_time,
EXTRACT(EPOCH FROM (NOW() - start_time)) AS duration,
NOW() AS end_time
FROM current_sessions
WHERE
guildid=_guildid
AND userid=_userid
) AS sessions
);
END;
$$ LANGUAGE PLPGSQL;
ALTER TABLE guild_config ADD COLUMN daily_study_cap INTEGER;
INSERT INTO VersionHistory (version, author) VALUES (6, 'v5-v6 Migration');

View File

@@ -0,0 +1,76 @@
-- Improved tasklist statistics
ALTER TABLE tasklist
ADD COLUMN completed_at TIMESTAMPTZ,
ADD COLUMN deleted_at TIMESTAMPTZ,
ALTER COLUMN created_at TYPE TIMESTAMPTZ USING created_at AT TIME ZONE 'UTC',
ALTER COLUMN last_updated_at TYPE TIMESTAMPTZ USING created_at AT TIME ZONE 'UTC';
UPDATE tasklist SET deleted_at = NOW() WHERE last_updated_at < NOW() - INTERVAL '24h';
UPDATE tasklist SET completed_at = last_updated_at WHERE complete;
ALTER TABLE tasklist
DROP COLUMN complete;
-- New member profile tags
CREATE TABLE member_profile_tags(
tagid SERIAL PRIMARY KEY,
guildid BIGINT NOT NULL,
userid BIGINT NOT NULL,
tag TEXT NOT NULL,
_timestamp TIMESTAMPTZ DEFAULT now(),
FOREIGN KEY (guildid, userid) REFERENCES members (guildid, userid)
);
CREATE INDEX member_profile_tags_members ON member_profile_tags (guildid, userid);
-- New member weekly and monthly goals
CREATE TABLE member_weekly_goals(
guildid BIGINT NOT NULL,
userid BIGINT NOT NULL,
weekid INTEGER NOT NULL, -- Epoch time of the start of the UTC week
study_goal INTEGER,
task_goal INTEGER,
_timestamp TIMESTAMPTZ DEFAULT now(),
PRIMARY KEY (guildid, userid, weekid),
FOREIGN KEY (guildid, userid) REFERENCES members (guildid, userid) ON DELETE CASCADE
);
CREATE INDEX member_weekly_goals_members ON member_weekly_goals (guildid, userid);
CREATE TABLE member_weekly_goal_tasks(
taskid SERIAL PRIMARY KEY,
guildid BIGINT NOT NULL,
userid BIGINT NOT NULL,
weekid INTEGER NOT NULL,
content TEXT NOT NULL,
completed BOOLEAN NOT NULL DEFAULT FALSE,
_timestamp TIMESTAMPTZ DEFAULT now(),
FOREIGN KEY (weekid, guildid, userid) REFERENCES member_weekly_goals (weekid, guildid, userid) ON DELETE CASCADE
);
CREATE INDEX member_weekly_goal_tasks_members_weekly ON member_weekly_goal_tasks (guildid, userid, weekid);
CREATE TABLE member_monthly_goals(
guildid BIGINT NOT NULL,
userid BIGINT NOT NULL,
monthid INTEGER NOT NULL, -- Epoch time of the start of the UTC month
study_goal INTEGER,
task_goal INTEGER,
_timestamp TIMESTAMPTZ DEFAULT now(),
PRIMARY KEY (guildid, userid, monthid),
FOREIGN KEY (guildid, userid) REFERENCES members (guildid, userid) ON DELETE CASCADE
);
CREATE INDEX member_monthly_goals_members ON member_monthly_goals (guildid, userid);
CREATE TABLE member_monthly_goal_tasks(
taskid SERIAL PRIMARY KEY,
guildid BIGINT NOT NULL,
userid BIGINT NOT NULL,
monthid INTEGER NOT NULL,
content TEXT NOT NULL,
completed BOOLEAN NOT NULL DEFAULT FALSE,
_timestamp TIMESTAMPTZ DEFAULT now(),
FOREIGN KEY (monthid, guildid, userid) REFERENCES member_monthly_goals (monthid, guildid, userid) ON DELETE CASCADE
);
CREATE INDEX member_monthly_goal_tasks_members_monthly ON member_monthly_goal_tasks (guildid, userid, monthid);
INSERT INTO VersionHistory (version, author) VALUES (7, 'v6-v7 migration');

View File

@@ -4,7 +4,7 @@ CREATE TABLE VersionHistory(
time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL, time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL,
author TEXT author TEXT
); );
INSERT INTO VersionHistory (version, author) VALUES (5, 'Initial Creation'); INSERT INTO VersionHistory (version, author) VALUES (7, 'Initial Creation');
CREATE OR REPLACE FUNCTION update_timestamp_column() CREATE OR REPLACE FUNCTION update_timestamp_column()
@@ -78,7 +78,7 @@ CREATE TABLE guild_config(
returning_message TEXT, returning_message TEXT,
starting_funds INTEGER, starting_funds INTEGER,
persist_roles BOOLEAN, persist_roles BOOLEAN,
max_daily_study INTEGER daily_study_cap INTEGER
); );
CREATE TABLE ignored_members( CREATE TABLE ignored_members(
@@ -135,10 +135,11 @@ CREATE TABLE tasklist(
taskid SERIAL PRIMARY KEY, taskid SERIAL PRIMARY KEY,
userid BIGINT NOT NULL, userid BIGINT NOT NULL,
content TEXT NOT NULL, content TEXT NOT NULL,
complete BOOL DEFAULT FALSE,
rewarded BOOL DEFAULT FALSE, rewarded BOOL DEFAULT FALSE,
created_at TIMESTAMP DEFAULT (now() at time zone 'utc'), deleted_at TIMESTAMPTZ,
last_updated_at TIMESTAMP DEFAULT (now() at time zone 'utc') completed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ,
last_updated_at TIMESTAMPTZ
); );
CREATE INDEX tasklist_users ON tasklist (userid); CREATE INDEX tasklist_users ON tasklist (userid);
@@ -412,12 +413,13 @@ update_timestamp_column();
-- Study Session Data {{{ -- Study Session Data {{{
CREATE TYPE SessionChannelType AS ENUM ( CREATE TYPE SessionChannelType AS ENUM (
'STANDARD',
'ACCOUNTABILITY', 'ACCOUNTABILITY',
'RENTED', 'RENTED',
'EXTERNAL', 'EXTERNAL',
'MIGRATED'
); );
CREATE TABLE session_history( CREATE TABLE session_history(
sessionid SERIAL PRIMARY KEY, sessionid SERIAL PRIMARY KEY,
guildid BIGINT NOT NULL, guildid BIGINT NOT NULL,
@@ -453,6 +455,43 @@ CREATE TABLE current_sessions(
CREATE UNIQUE INDEX current_session_members ON current_sessions (guildid, userid); CREATE UNIQUE INDEX current_session_members ON current_sessions (guildid, userid);
CREATE FUNCTION study_time_since(_guildid BIGINT, _userid BIGINT, _timestamp TIMESTAMPTZ)
RETURNS INTEGER
AS $$
BEGIN
RETURN (
SELECT
SUM(
CASE
WHEN start_time >= _timestamp THEN duration
ELSE EXTRACT(EPOCH FROM (end_time - _timestamp))
END
)
FROM (
SELECT
start_time,
duration,
(start_time + duration * interval '1 second') AS end_time
FROM session_history
WHERE
guildid=_guildid
AND userid=_userid
AND (start_time + duration * interval '1 second') >= _timestamp
UNION
SELECT
start_time,
EXTRACT(EPOCH FROM (NOW() - start_time)) AS duration,
NOW() AS end_time
FROM current_sessions
WHERE
guildid=_guildid
AND userid=_userid
) AS sessions
);
END;
$$ LANGUAGE PLPGSQL;
CREATE FUNCTION close_study_session(_guildid BIGINT, _userid BIGINT) CREATE FUNCTION close_study_session(_guildid BIGINT, _userid BIGINT)
RETURNS SETOF members RETURNS SETOF members
AS $$ AS $$
@@ -476,7 +515,7 @@ AS $$
) SELECT ) SELECT
guildid, userid, channelid, channel_type, start_time, guildid, userid, channelid, channel_type, start_time,
total_duration, total_stream_duration, total_video_duration, total_live_duration, total_duration, total_stream_duration, total_video_duration, total_live_duration,
(total_duration * hourly_coins + live_duration * hourly_live_coins) / 60 (total_duration * hourly_coins + live_duration * hourly_live_coins) / 3600
FROM current_sesh FROM current_sesh
RETURNING * RETURNING *
) )
@@ -506,7 +545,7 @@ CREATE VIEW members_totals AS
*, *,
sesh.start_time AS session_start, sesh.start_time AS session_start,
tracked_time + COALESCE(sesh.total_duration, 0) AS total_tracked_time, tracked_time + COALESCE(sesh.total_duration, 0) AS total_tracked_time,
coins + COALESCE((sesh.total_duration * sesh.hourly_coins + sesh.live_duration * sesh.hourly_live_coins) / 60, 0) AS total_coins coins + COALESCE((sesh.total_duration * sesh.hourly_coins + sesh.live_duration * sesh.hourly_live_coins) / 3600, 0) AS total_coins
FROM members FROM members
LEFT JOIN current_sessions_totals sesh USING (guildid, userid); LEFT JOIN current_sessions_totals sesh USING (guildid, userid);
@@ -525,7 +564,7 @@ CREATE VIEW current_study_badges AS
*, *,
(SELECT r.badgeid (SELECT r.badgeid
FROM study_badges r FROM study_badges r
WHERE r.guildid = members_totals.guildid AND members_totals.tracked_time > r.required_time WHERE r.guildid = members_totals.guildid AND members_totals.total_tracked_time > r.required_time
ORDER BY r.required_time DESC ORDER BY r.required_time DESC
LIMIT 1) AS current_study_badgeid LIMIT 1) AS current_study_badgeid
FROM members_totals; FROM members_totals;
@@ -644,4 +683,67 @@ CREATE TABLE past_member_roles(
CREATE INDEX member_role_persistence_members ON past_member_roles (guildid, userid); CREATE INDEX member_role_persistence_members ON past_member_roles (guildid, userid);
-- }}} -- }}}
-- Member profile tags {{{
CREATE TABLE member_profile_tags(
tagid SERIAL PRIMARY KEY,
guildid BIGINT NOT NULL,
userid BIGINT NOT NULL,
tag TEXT NOT NULL,
_timestamp TIMESTAMPTZ DEFAULT now(),
FOREIGN KEY (guildid, userid) REFERENCES members (guildid, userid)
);
CREATE INDEX member_profile_tags_members ON member_profile_tags (guildid, userid);
-- }}}
-- Member goals {{{
CREATE TABLE member_weekly_goals(
guildid BIGINT NOT NULL,
userid BIGINT NOT NULL,
weekid INTEGER NOT NULL, -- Epoch time of the start of the UTC week
study_goal INTEGER,
task_goal INTEGER,
_timestamp TIMESTAMPTZ DEFAULT now(),
PRIMARY KEY (guildid, userid, weekid),
FOREIGN KEY (guildid, userid) REFERENCES members (guildid, userid) ON DELETE CASCADE
);
CREATE INDEX member_weekly_goals_members ON member_weekly_goals (guildid, userid);
CREATE TABLE member_weekly_goal_tasks(
taskid SERIAL PRIMARY KEY,
guildid BIGINT NOT NULL,
userid BIGINT NOT NULL,
weekid INTEGER NOT NULL,
content TEXT NOT NULL,
completed BOOLEAN NOT NULL DEFAULT FALSE,
_timestamp TIMESTAMPTZ DEFAULT now(),
FOREIGN KEY (weekid, guildid, userid) REFERENCES member_weekly_goals (weekid, guildid, userid) ON DELETE CASCADE
);
CREATE INDEX member_weekly_goal_tasks_members_weekly ON member_weekly_goal_tasks (guildid, userid, weekid);
CREATE TABLE member_monthly_goals(
guildid BIGINT NOT NULL,
userid BIGINT NOT NULL,
monthid INTEGER NOT NULL, -- Epoch time of the start of the UTC month
study_goal INTEGER,
task_goal INTEGER,
_timestamp TIMESTAMPTZ DEFAULT now(),
PRIMARY KEY (guildid, userid, monthid),
FOREIGN KEY (guildid, userid) REFERENCES members (guildid, userid) ON DELETE CASCADE
);
CREATE INDEX member_monthly_goals_members ON member_monthly_goals (guildid, userid);
CREATE TABLE member_monthly_goal_tasks(
taskid SERIAL PRIMARY KEY,
guildid BIGINT NOT NULL,
userid BIGINT NOT NULL,
monthid INTEGER NOT NULL,
content TEXT NOT NULL,
completed BOOLEAN NOT NULL DEFAULT FALSE,
_timestamp TIMESTAMPTZ DEFAULT now(),
FOREIGN KEY (monthid, guildid, userid) REFERENCES member_monthly_goals (monthid, guildid, userid) ON DELETE CASCADE
);
CREATE INDEX member_monthly_goal_tasks_members_monthly ON member_monthly_goal_tasks (guildid, userid, monthid);
-- }}}
-- vim: set fdm=marker: -- vim: set fdm=marker: