# 449 lines, 17 KiB, Python
import asyncio
|
|
import datetime
|
|
import collections
|
|
import traceback
|
|
import logging
|
|
import discord
|
|
from typing import Dict
|
|
from discord.utils import sleep_until
|
|
|
|
from meta import client
|
|
from data import NULL, NOTNULL, tables
|
|
from data.conditions import LEQ
|
|
from settings import GuildSettings
|
|
|
|
from .TimeSlot import TimeSlot
|
|
from .lib import utc_now
|
|
from .data import accountability_rooms, accountability_members
|
|
from .module import module
|
|
|
|
|
|
# Held by `turnover` and the resume logic while they move members between voice
# channels; `room_watchdog` ignores voice state events whenever it is locked.
voice_ignore_lock: asyncio.Lock = asyncio.Lock()
|
|
|
|
|
|
class AccountabilityGuild:
    """
    Per-guild accountability state.

    Holds the slot currently running in a guild (`current_slot`) and the slot
    opened for the next session (`upcoming_slot`). Constructing an instance
    registers it in the class-level `cache`, keyed by guild id.
    """
    __slots__ = ('guildid', 'current_slot', 'upcoming_slot')

    # Shared registry of active guilds: guildid -> AccountabilityGuild
    cache: Dict[int, 'AccountabilityGuild'] = {}

    def __init__(self, guildid):
        self.guildid = guildid
        self.current_slot, self.upcoming_slot = None, None
        # Registration happens here, so creating an instance activates the guild.
        self.cache[guildid] = self

    @property
    def guild(self):
        """The guild object for `guildid` as returned by `client.get_guild`."""
        return client.get_guild(self.guildid)

    @property
    def guild_settings(self):
        """A new `GuildSettings` wrapper for this guild."""
        return GuildSettings(self.guildid)

    def advance(self):
        """Promote the upcoming slot to the current slot, leaving no upcoming slot."""
        promoted = self.upcoming_slot
        self.upcoming_slot = None
        self.current_slot = promoted
|
|
|
|
|
|
async def _create_lobby(aguild, guild, slot):
    """
    Create the accountability lobby text channel for `guild` inside `slot.category`.

    On success the new channel is stored in the guild's `accountability_lobby`
    setting, attached to the slot as `slot.lobby`, and the creation is written to
    the guild event log. If the request raises `discord.HTTPException`, the
    failure is written to the event log instead.

    Returns True when the lobby was created, False otherwise.
    """
    # TODO: Consider putting in TimeSlot.open().. or even better in accountability_lobby.create()
    try:
        channel = await guild.create_text_channel(
            name="accountability-lobby",
            category=slot.category,
            reason="Automatic creation of accountability lobby."
        )
        aguild.guild_settings.accountability_lobby.value = channel
        slot.lobby = channel
    except discord.HTTPException:
        # Event log failure; the caller skips the session
        aguild.guild_settings.event_log.log(
            "Failed to create the accountability lobby text channel.\n"
            "Please set the lobby channel manually with `config`."
        )
        return False

    # Event log creation
    aguild.guild_settings.event_log.log(
        "Automatically created an accountability lobby channel {}.".format(channel.mention)
    )
    return True


async def open_next(start_time):
    """
    Open all the upcoming accountability rooms, and fire channel notify.
    To be executed ~5 minutes to the hour.

    For every guild in `AccountabilityGuild.cache`, a `TimeSlot` starting at
    `start_time` is built from that guild's `accountability_rooms` row (if any),
    loaded with the booked member ids, and opened. Successfully opened slots
    become the guild's `upcoming_slot`. For booked slots, the channel and message
    ids returned by `TimeSlot.open` are written back to `accountability_rooms`.

    Guilds that cannot be resolved, or whose accountability category is missing,
    are removed from the cache.

    Args:
        start_time: start time of the sessions to open; matched against
            `accountability_rooms.start_at`.
    """
    # Pre-fetch the new slot data, also populating the table caches
    room_data = accountability_rooms.fetch_rows_where(
        start_at=start_time
    )
    guild_rows = {row.guildid: row for row in room_data}
    member_data = accountability_members.fetch_rows_where(
        slotid=[row.slotid for row in room_data]
    ) if room_data else []
    slot_memberids = collections.defaultdict(list)
    for row in member_data:
        slot_memberids[row.slotid].append(row.userid)

    # Open a new slot in each accountability guild
    to_update = []  # (channelid, messageid, slotid) for each booked slot that opened
    try:
        # Iterate over a copy: guilds may be popped from the cache below
        for aguild in list(AccountabilityGuild.cache.values()):
            guild = aguild.guild
            if not guild:
                # Unload guild from cache
                AccountabilityGuild.cache.pop(aguild.guildid, None)
                continue

            # Initialise next TimeSlot
            slot = TimeSlot(
                guild,
                start_time,
                data=guild_rows.get(aguild.guildid, None)
            )
            slot.load(memberids=slot_memberids[slot.data.slotid] if slot.data else None)

            if not slot.category:
                # Log and unload guild
                aguild.guild_settings.event_log.log(
                    "The accountability category couldn't be found!\n"
                    "Shutting down the accountability system in this server.\n"
                    "To re-activate, please reconfigure `config accountability_category`."
                )
                AccountabilityGuild.cache.pop(aguild.guildid, None)
                await slot.cancel()
                continue
            elif not slot.lobby and not await _create_lobby(aguild, guild, slot):
                # No lobby and creating one failed (already event logged); skip this session
                await slot.cancel()
                continue

            results = await slot.open()
            if results is None:
                # Couldn't open the channel for some reason.
                # Should already have been logged in `open`.
                # Skip this session
                await slot.cancel()
                continue
            elif slot.data:
                to_update.append((results[0], results[1], slot.data.slotid))

            # Time slot should now be open and ready to start
            aguild.upcoming_slot = slot
    finally:
        # Persist the ids of every slot opened so far, even if a later guild raised.
        # The resume logic treats rows without channelid/messageid as never opened,
        # so skipping this would misclassify rooms that really were opened.
        if to_update:
            accountability_rooms.update_many(
                *to_update,
                set_keys=('channelid', 'messageid'),
                where_keys=('slotid',)
            )
|
|
|
|
|
|
async def turnover():
    """
    Switchover from the current accountability rooms to the next ones.
    To be executed as close as possible to the hour.

    While `voice_ignore_lock` is held (so `room_watchdog` ignores the resulting
    voice movement) this:
      1. credits members still in the finishing rooms with their time since joining;
      2. advances every guild (upcoming slot -> current slot);
      3. moves booked members who are in voice into their new room;
      4. closes the finished rooms and sets their `closed_at`;
      5. resets duration/`last_joined_at` for members now in the new rooms;
      6. starts the new rooms.

    A failure while closing a finished room is logged and does not stop the
    remaining steps.
    """
    now = utc_now()

    # Open event lock so we don't read voice channel movement
    async with voice_ignore_lock:
        # Update session data for completed sessions
        last_slots = [
            aguild.current_slot for aguild in AccountabilityGuild.cache.values()
            if aguild.current_slot is not None
        ]

        # Members still marked as joined get the time since they joined added on
        to_update = [
            (mem.data.duration + int((now - mem.data.last_joined_at).total_seconds()), None, mem.slotid, mem.userid)
            for slot in last_slots for mem in slot.members.values()
            if mem.data.last_joined_at
        ]
        if to_update:
            accountability_members.update_many(
                *to_update,
                set_keys=('duration', 'last_joined_at'),
                where_keys=('slotid', 'userid'),
                cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
            )

        # Rotate guild sessions
        for aguild in AccountabilityGuild.cache.values():
            aguild.advance()

        # TODO: (FUTURE) with high volume, we might want to start the sessions before moving the members.
        # We could break up the session starting?

        # Move members of the next session over to the session channel
        # This includes any members of the session just complete
        current_slots = [
            aguild.current_slot for aguild in AccountabilityGuild.cache.values()
            if aguild.current_slot is not None
        ]
        movement_tasks = (
            mem.member.edit(
                voice_channel=slot.channel,
                reason="Moving to booked accountability session."
            )
            for slot in current_slots
            for mem in slot.members.values()
            if mem.member.voice and mem.member.voice.channel != slot.channel
        )
        # We return exceptions here to ignore any permission issues that occur with moving members.
        # It's also possible (likely) that members will move while we are moving other members
        # Returning the exceptions ensures that they are explicitly ignored
        await asyncio.gather(
            *movement_tasks,
            return_exceptions=True
        )

        # Close all completed rooms. One failing close must not abort the turnover
        # (the new rooms would never be started), so failures are collected and logged.
        close_results = await asyncio.gather(
            *(slot.close() for slot in last_slots),
            return_exceptions=True
        )
        for slot, result in zip(last_slots, close_results):
            if isinstance(result, BaseException):
                client.log(
                    "Failed to close accountability room (slotid {}). "
                    "Exception traceback follows.\n{}".format(
                        slot.data.slotid if slot.data else None,
                        "".join(traceback.format_exception(type(result), result, result.__traceback__))
                    ),
                    context="ACCOUNTABILITY_LOOP",
                    level=logging.ERROR
                )

        # Mark completed rooms as closed (failed closes included, as the resume logic also does)
        update_slots = [slot.data.slotid for slot in last_slots if slot.data]
        if update_slots:
            accountability_rooms.update_where(
                {'closed_at': utc_now()},
                slotid=update_slots
            )

        # Update session data of all members in new channels
        member_session_data = [
            (0, slot.start_time, mem.slotid, mem.userid)
            for slot in current_slots
            for mem in slot.members.values()
            if mem.member.voice and mem.member.voice.channel == slot.channel
        ]
        if member_session_data:
            accountability_members.update_many(
                *member_session_data,
                set_keys=('duration', 'last_joined_at'),
                where_keys=('slotid', 'userid'),
                cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
            )

        # Start all the current rooms
        await asyncio.gather(
            *(slot.start() for slot in current_slots)
        )
|
|
|
|
|
|
# Strong references to the fire-and-forget tasks spawned by `room_watchdog`.
# The event loop only keeps weak references to tasks, so without this a pending
# move could be garbage collected before it completes.
_watchdog_tasks = set()


@client.add_after_event('voice_state_update')
async def room_watchdog(client, member, before, after):
    """
    Update session data when a member joins or leaves an accountability room.
    Ignores events that occur while `voice_ignore_lock` is held.

    Only channel changes of members booked into their guild's current slot are
    handled:
      * moving to a voice channel other than the slot's room schedules a move back;
      * leaving the room adds the time since `last_joined_at` to `duration` and
        clears `last_joined_at`;
      * joining the room sets `last_joined_at` to now.
    The slot status is refreshed after either data change.
    """
    if voice_ignore_lock.locked() or before.channel == after.channel:
        return

    aguild = AccountabilityGuild.cache.get(member.guild.id)
    if not (aguild and aguild.current_slot and aguild.current_slot.channel):
        return

    slot = aguild.current_slot
    if member.id not in slot.members:
        return

    if after.channel and after.channel.id != slot.channel.id:
        # Summon them back!
        task = asyncio.create_task(member.edit(voice_channel=slot.channel))
        _watchdog_tasks.add(task)
        task.add_done_callback(_watchdog_tasks.discard)

    data = slot.members[member.id].data

    if before.channel and before.channel.id == slot.channel.id:
        # Left accountability room
        with data.batch_update():
            # The join may have happened while events were being ignored, leaving
            # `last_joined_at` unset; there is no start time to credit from then.
            if data.last_joined_at:
                data.duration += int((utc_now() - data.last_joined_at).total_seconds())
            data.last_joined_at = None
        await slot.update_status()
    elif after.channel and after.channel.id == slot.channel.id:
        # Joined accountability room
        with data.batch_update():
            data.last_joined_at = utc_now()
        await slot.update_status()
|
|
|
|
|
|
async def _accountability_loop():
    """
    Runloop that fires the accountability room tasks on schedule.

    Twice an hour:
      * at hh:55, `open_next` is run for the rooms starting at the next hour;
      * at hh:00, `turnover` switches over to those rooms.
    Any exception from either task is logged with its traceback, and the loop
    keeps running.
    """
    def report_failure(headline):
        # Only call this from inside an `except` block, so `format_exc` sees the active exception.
        client.log(
            headline + "Exception traceback follows.\n{}".format(traceback.format_exc()),
            context="ACCOUNTABILITY_LOOP",
            level=logging.ERROR
        )

    # Block until the client is ready
    while not client.is_ready():
        await asyncio.sleep(0.1)

    # Work out the first wake-up time.
    # Everything before the current time is assumed to be handled by the resume logic.
    started = utc_now()
    hour_start = started.replace(minute=0, second=0, microsecond=0)
    next_time = (
        hour_start + datetime.timedelta(hours=1)
        if started.minute >= 55
        else hour_start.replace(minute=55)
    )

    while True:
        # TODO: (FUTURE) handle cases where we actually execute much late than expected
        await sleep_until(next_time)
        if next_time.minute == 55:
            # Five minutes to the hour: open the rooms for the coming hour
            next_time += datetime.timedelta(minutes=5)
            try:
                await open_next(next_time)
            except Exception:
                # Unknown exception. Catch it so the loop doesn't die.
                report_failure("Error while opening new accountability rooms! ")
        elif next_time.minute == 0:
            # Top of the hour: switch over to the rooms opened at hh:55
            try:
                await turnover()
            except Exception:
                # Unknown exception. Catch it so the loop doesn't die.
                report_failure("Error while starting accountability rooms! ")
            next_time += datetime.timedelta(minutes=55)
|
|
|
|
|
|
def _save_member_sessions(updates):
    """
    Write `(duration, last_joined_at, slotid, userid)` tuples to `accountability_members`.

    Does nothing when `updates` is empty.
    """
    if updates:
        accountability_members.update_many(
            *updates,
            set_keys=('duration', 'last_joined_at'),
            where_keys=('slotid', 'userid'),
            cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
        )


async def _accountability_system_resume():
    """
    Logic for starting the accountability system from cold.
    Essentially, session and state resume logic.

    * Unclosed rooms whose hour has fully elapsed are closed. For rooms that were
      opened, members still marked as joined are credited up to the end of that
      hour and the room's channels are closed where the guild is available.
    * Unclosed rooms for the current hour are re-opened via `open_next` and
      promoted to the current slot. Member join state is reconciled with who is
      actually in the room channel, then the rooms are started.
    * If we are in the last five minutes of the hour, next hour's rooms are opened.
    """
    now = utc_now()

    # Fetch the open room data, only takes into account currently running sessions.
    # May include sessions that were never opened, or opened but never started
    # Does not include sessions that were opened that start on the next hour
    open_room_data = accountability_rooms.fetch_rows_where(
        closed_at=NULL,
        start_at=LEQ(now),
        _extra="ORDER BY start_at ASC"
    )

    if open_room_data:
        # Extract member data of these rows
        member_data = accountability_members.fetch_rows_where(
            slotid=[row.slotid for row in open_room_data]
        )
        slot_members = collections.defaultdict(list)
        for row in member_data:
            slot_members[row.slotid].append(row)

        # Filter these into expired rooms and current rooms.
        # `<=` so a room exactly one hour old counts as expired: `open_next(current_hour)`
        # below only picks up rooms starting at the current hour, so it would never be resumed.
        expired_room_data = []
        current_room_data = []
        for row in open_room_data:
            if row.start_at + datetime.timedelta(hours=1) <= now:
                expired_room_data.append(row)
            else:
                current_room_data.append(row)

        session_updates = []

        # TODO URGENT: Batch room updates here

        # Expire the expired rooms
        for row in expired_room_data:
            if row.channelid is None or row.messageid is None:
                # TODO refunds here
                # If the rooms were never opened, close them and skip
                row.closed_at = now
                continue

            # If the rooms were opened and maybe started, make optimistic guesses on session data and close.
            session_end = row.start_at + datetime.timedelta(hours=1)
            session_updates.extend(
                (mow.duration + int((session_end - mow.last_joined_at).total_seconds()),
                 None, mow.slotid, mow.userid)
                for mow in slot_members[row.slotid] if mow.last_joined_at
            )

            guild = client.get_guild(row.guildid)
            if guild is None:
                # The guild isn't available, so there are no channels we can clean up.
                # Close the record so it isn't picked up again.
                row.closed_at = now
                continue

            slot = TimeSlot(guild, row.start_at, data=row)
            slot.load(memberids=[mow.userid for mow in slot_members[row.slotid]])
            row.closed_at = now
            try:
                await slot.close()
            except discord.HTTPException:
                pass

        # Load the in-progress room data
        if current_room_data:
            async with voice_ignore_lock:
                current_hour = now.replace(minute=0, second=0, microsecond=0)
                await open_next(current_hour)
                for aguild in AccountabilityGuild.cache.values():
                    aguild.advance()

                current_slots = [
                    aguild.current_slot
                    for aguild in AccountabilityGuild.cache.values()
                    if aguild.current_slot
                ]

                # Marked as joined but no longer in the room: credit them up to now
                session_updates.extend(
                    (mem.data.duration + int((now - mem.data.last_joined_at).total_seconds()),
                     None, mem.slotid, mem.userid)
                    for slot in current_slots
                    for mem in slot.members.values()
                    if mem.data.last_joined_at and mem.member not in slot.channel.members
                )

                # In the room but not marked as joined: start their timer from now
                session_updates.extend(
                    (mem.data.duration,
                     now, mem.slotid, mem.userid)
                    for slot in current_slots
                    for mem in slot.members.values()
                    if not mem.data.last_joined_at and mem.member in slot.channel.members
                )

                _save_member_sessions(session_updates)

                await asyncio.gather(
                    *(slot.start() for slot in current_slots)
                )
        else:
            _save_member_sessions(session_updates)

    # If we are in the last five minutes of the hour, open new rooms.
    # Note that these may already have been opened, or they may not have been.
    if now.minute >= 55:
        await open_next(
            now.replace(minute=0, second=0, microsecond=0) + datetime.timedelta(hours=1)
        )
|
|
|
|
|
|
# Handle on the running accountability runloop task. Keeping a strong reference
# stops the task from being garbage collected mid-run and leaves shutdown code a
# task to cancel.
_accountability_loop_task = None


@module.launch_task
async def launch_accountability_system(client):
    """
    Launcher for the accountability system.
    Resumes saved sessions, and starts the accountability loop.

    Every guild with an `accountability_category` configured is registered in
    `AccountabilityGuild.cache`. Saved sessions are resumed, and then the runloop
    is scheduled as a task.
    """
    global _accountability_loop_task

    # Load the AccountabilityGuild cache
    guilds = tables.guild_config.fetch_rows_where(
        accountability_category=NOTNULL
    )
    for guild in guilds:
        # The constructor registers the instance in AccountabilityGuild.cache
        AccountabilityGuild(guild.guildid)

    await _accountability_system_resume()
    _accountability_loop_task = asyncio.create_task(_accountability_loop())
|
|
|
|
|
|
async def unload_accountability(client):
    """
    Save the current sessions and cancel the runloop in preparation for client shutdown.

    NOTE(review): not implemented yet; this coroutine currently does nothing and returns None.
    """
    return None
|