Files
croccybot/bot/modules/accountability/tracker.py
Conatum b2aa651eaa fix (rooms): Harden against race conditions.
Add locking to room init, turnover, and cancellation.
Add cleanup of nonexistent members in slot init.
Fix an issue where members were being charged for cancelling rooms.
2021-09-24 21:12:12 +03:00

464 lines
17 KiB
Python

import asyncio
import datetime
import collections
import traceback
import logging
import discord
from typing import Dict
from discord.utils import sleep_until
from meta import client
from data import NULL, NOTNULL, tables
from data.conditions import LEQ
from settings import GuildSettings
from .TimeSlot import TimeSlot
from .lib import utc_now
from .data import accountability_rooms, accountability_members
from .module import module
voice_ignore_lock = asyncio.Lock()
room_lock = asyncio.Lock()
def locker(lock):
"""
Function decorator to wrap the function in a provided Lock
"""
def decorator(func):
async def wrapped(*args, **kwargs):
async with lock:
return await func(*args, **kwargs)
return wrapped
return decorator
class AccountabilityGuild:
__slots__ = ('guildid', 'current_slot', 'upcoming_slot')
cache: Dict[int, 'AccountabilityGuild'] = {} # Map guildid -> AccountabilityGuild
def __init__(self, guildid):
self.guildid = guildid
self.current_slot = None
self.upcoming_slot = None
self.cache[guildid] = self
@property
def guild(self):
return client.get_guild(self.guildid)
@property
def guild_settings(self):
return GuildSettings(self.guildid)
def advance(self):
self.current_slot = self.upcoming_slot
self.upcoming_slot = None
async def open_next(start_time):
"""
Open all the upcoming accountability rooms, and fire channel notify.
To be executed ~5 minutes to the hour.
"""
# Pre-fetch the new slot data, also populating the table caches
room_data = accountability_rooms.fetch_rows_where(
start_at=start_time
)
guild_rows = {row.guildid: row for row in room_data}
member_data = accountability_members.fetch_rows_where(
slotid=[row.slotid for row in room_data]
) if room_data else []
slot_memberids = collections.defaultdict(list)
for row in member_data:
slot_memberids[row.slotid].append(row.userid)
print(room_data, member_data)
# Open a new slot in each accountability guild
to_update = [] # Cache of slot update data to be applied at the end
for aguild in list(AccountabilityGuild.cache.values()):
guild = aguild.guild
if guild:
# Initialise next TimeSlot
slot = TimeSlot(
guild,
start_time,
data=guild_rows.get(aguild.guildid, None)
)
slot.load(memberids=slot_memberids[slot.data.slotid] if slot.data else None)
if not slot.category:
# Log and unload guild
aguild.guild_settings.event_log.log(
"The accountability category couldn't be found!\n"
"Shutting down the accountability system in this server.\n"
"To re-activate, please reconfigure `config accountability_category`."
)
AccountabilityGuild.cache.pop(aguild.guildid, None)
await slot.cancel()
continue
elif not slot.lobby:
# TODO: Consider putting in TimeSlot.open().. or even better in accountability_lobby.create()
# Create a new lobby
try:
channel = await guild.create_text_channel(
name="accountability-lobby",
category=slot.category,
reason="Automatic creation of accountability lobby."
)
aguild.guild_settings.accountability_lobby.value = channel
slot.lobby = channel
except discord.HTTPException:
# Event log failure and skip session
aguild.guild_settings.event_log.log(
"Failed to create the accountability lobby text channel.\n"
"Please set the lobby channel manually with `config`."
)
await slot.cancel()
continue
# Event log creation
aguild.guild_settings.event_log.log(
"Automatically created an accountability lobby channel {}.".format(channel.mention)
)
results = await slot.open()
if results is None:
# Couldn't open the channel for some reason.
# Should already have been logged in `open`.
# Skip this session
await slot.cancel()
continue
elif slot.data:
to_update.append((results[0], results[1], slot.data.slotid))
# Time slot should now be open and ready to start
aguild.upcoming_slot = slot
else:
# Unload guild from cache
AccountabilityGuild.cache.pop(aguild.guildid, None)
# Update slot data
if to_update:
accountability_rooms.update_many(
*to_update,
set_keys=('channelid', 'messageid'),
where_keys=('slotid',)
)
async def turnover():
"""
Switchover from the current accountability rooms to the next ones.
To be executed as close as possible to the hour.
"""
now = utc_now()
# Open event lock so we don't read voice channel movement
async with voice_ignore_lock:
# Update session data for completed sessions
last_slots = [
aguild.current_slot for aguild in AccountabilityGuild.cache.values()
if aguild.current_slot is not None
]
to_update = [
(mem.data.duration + int((now - mem.data.last_joined_at).total_seconds()), None, mem.slotid, mem.userid)
for slot in last_slots for mem in slot.members.values()
if mem.data.last_joined_at
]
if to_update:
accountability_members.update_many(
*to_update,
set_keys=('duration', 'last_joined_at'),
where_keys=('slotid', 'userid'),
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
)
# Rotate guild sessions
[aguild.advance() for aguild in AccountabilityGuild.cache.values()]
# TODO: (FUTURE) with high volume, we might want to start the sessions before moving the members.
# We could break up the session starting?
# Move members of the next session over to the session channel
# This includes any members of the session just complete
current_slots = [
aguild.current_slot for aguild in AccountabilityGuild.cache.values()
if aguild.current_slot is not None
]
movement_tasks = (
mem.member.edit(
voice_channel=slot.channel,
reason="Moving to booked accountability session."
)
for slot in current_slots
for mem in slot.members.values()
if mem.data and mem.member and mem.member.voice and mem.member.voice.channel != slot.channel
)
# We return exceptions here to ignore any permission issues that occur with moving members.
# It's also possible (likely) that members will move while we are moving other members
# Returning the exceptions ensures that they are explicitly ignored
await asyncio.gather(
*movement_tasks,
return_exceptions=True
)
# Close all completed rooms, update data
await asyncio.gather(*(slot.close() for slot in last_slots))
update_slots = [slot.data.slotid for slot in last_slots if slot.data]
if update_slots:
accountability_rooms.update_where(
{'closed_at': utc_now()},
slotid=update_slots
)
# Update session data of all members in new channels
member_session_data = [
(0, slot.start_time, mem.slotid, mem.userid)
for slot in current_slots
for mem in slot.members.values()
if mem.member.voice and mem.member.voice.channel == slot.channel
]
if member_session_data:
accountability_members.update_many(
*member_session_data,
set_keys=('duration', 'last_joined_at'),
where_keys=('slotid', 'userid'),
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
)
# Start all the current rooms
await asyncio.gather(
*(slot.start() for slot in current_slots)
)
@client.add_after_event('voice_state_update')
async def room_watchdog(client, member, before, after):
"""
Update session data when a member joins or leaves an accountability room.
Ignores events that occur while `voice_ignore_lock` is held.
"""
if not voice_ignore_lock.locked() and before.channel != after.channel:
aguild = AccountabilityGuild.cache.get(member.guild.id)
if aguild and aguild.current_slot and aguild.current_slot.channel:
slot = aguild.current_slot
if member.id in slot.members:
if after.channel and after.channel.id != slot.channel.id:
# Summon them back!
asyncio.create_task(member.edit(voice_channel=slot.channel))
slot_member = slot.members[member.id]
data = slot_member.data
if before.channel and before.channel.id == slot.channel.id:
# Left accountability room
with data.batch_update():
data.duration += int((utc_now() - data.last_joined_at).total_seconds())
data.last_joined_at = None
await slot.update_status()
elif after.channel and after.channel.id == slot.channel.id:
# Joined accountability room
with data.batch_update():
data.last_joined_at = utc_now()
await slot.update_status()
async def _accountability_loop():
"""
Runloop in charge of executing the room update tasks at the correct times.
"""
# Wait until ready
while not client.is_ready():
await asyncio.sleep(0.1)
# Calculate starting next_time
# Assume the resume logic has taken care of all events/tasks before current_time
now = utc_now()
if now.minute < 55:
next_time = now.replace(minute=55, second=0, microsecond=0)
else:
next_time = now.replace(minute=0, second=0, microsecond=0) + datetime.timedelta(hours=1)
# Executor loop
while True:
# TODO: (FUTURE) handle cases where we actually execute much late than expected
await sleep_until(next_time)
if next_time.minute == 55:
next_time = next_time + datetime.timedelta(minutes=5)
# Open next sessions
try:
async with room_lock:
await open_next(next_time)
except Exception:
# Unknown exception. Catch it so the loop doesn't die.
client.log(
"Error while opening new accountability rooms! "
"Exception traceback follows.\n{}".format(
traceback.format_exc()
),
context="ACCOUNTABILITY_LOOP",
level=logging.ERROR
)
elif next_time.minute == 0:
# Start new sessions
try:
async with room_lock:
await turnover()
except Exception:
# Unknown exception. Catch it so the loop doesn't die.
client.log(
"Error while starting accountability rooms! "
"Exception traceback follows.\n{}".format(
traceback.format_exc()
),
context="ACCOUNTABILITY_LOOP",
level=logging.ERROR
)
next_time = next_time + datetime.timedelta(minutes=55)
async def _accountability_system_resume():
"""
Logic for starting the accountability system from cold.
Essentially, session and state resume logic.
"""
now = utc_now()
# Fetch the open room data, only takes into account currently running sessions.
# May include sessions that were never opened, or opened but never started
# Does not include sessions that were opened that start on the next hour
open_room_data = accountability_rooms.fetch_rows_where(
closed_at=NULL,
start_at=LEQ(now),
_extra="ORDER BY start_at ASC"
)
if open_room_data:
# Extract member data of these rows
member_data = accountability_members.fetch_rows_where(
slotid=[row.slotid for row in open_room_data]
)
slot_members = collections.defaultdict(list)
for row in member_data:
slot_members[row.slotid].append(row)
# Filter these into expired rooms and current rooms
expired_room_data = []
current_room_data = []
for row in open_room_data:
if row.start_at + datetime.timedelta(hours=1) < now:
expired_room_data.append(row)
else:
current_room_data.append(row)
session_updates = []
# TODO URGENT: Batch room updates here
# Expire the expired rooms
for row in expired_room_data:
if row.channelid is None or row.messageid is None:
# TODO refunds here
# If the rooms were never opened, close them and skip
row.closed_at = now
else:
# If the rooms were opened and maybe started, make optimistic guesses on session data and close.
session_end = row.start_at + datetime.timedelta(hours=1)
session_updates.extend(
(mow.duration + int((session_end - mow.last_joined_at).total_seconds()),
None, mow.slotid, mow.userid)
for mow in slot_members[row.slotid] if mow.last_joined_at
)
slot = TimeSlot(client.get_guild(row.guildid), row.start_at, data=row).load(
memberids=[mow.userid for mow in slot_members[row.slotid]]
)
row.closed_at = now
try:
await slot.close()
except discord.HTTPException:
pass
# Load the in-progress room data
if current_room_data:
async with voice_ignore_lock:
current_hour = now.replace(minute=0, second=0, microsecond=0)
await open_next(current_hour)
[aguild.advance() for aguild in AccountabilityGuild.cache.values()]
current_slots = [
aguild.current_slot
for aguild in AccountabilityGuild.cache.values()
if aguild.current_slot
]
session_updates.extend(
(mem.data.duration + int((now - mem.data.last_joined_at).total_seconds()),
None, mem.slotid, mem.userid)
for slot in current_slots
for mem in slot.members.values()
if mem.data.last_joined_at and mem.member not in slot.channel.members
)
session_updates.extend(
(mem.data.duration,
now, mem.slotid, mem.userid)
for slot in current_slots
for mem in slot.members.values()
if not mem.data.last_joined_at and mem.member in slot.channel.members
)
if session_updates:
accountability_members.update_many(
*session_updates,
set_keys=('duration', 'last_joined_at'),
where_keys=('slotid', 'userid'),
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
)
await asyncio.gather(
*(aguild.current_slot.start()
for aguild in AccountabilityGuild.cache.values() if aguild.current_slot)
)
else:
if session_updates:
accountability_members.update_many(
*session_updates,
set_keys=('duration', 'last_joined_at'),
where_keys=('slotid', 'userid'),
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
)
# If we are in the last five minutes of the hour, open new rooms.
# Note that these may already have been opened, or they may not have been.
if now.minute >= 55:
await open_next(
now.replace(minute=0, second=0, microsecond=0) + datetime.timedelta(hours=1)
)
@module.launch_task
async def launch_accountability_system(client):
"""
Launcher for the accountability system.
Resumes saved sessions, and starts the accountability loop.
"""
# Load the AccountabilityGuild cache
guilds = tables.guild_config.fetch_rows_where(
accountability_category=NOTNULL
)
[AccountabilityGuild(guild.guildid) for guild in guilds]
await _accountability_system_resume()
asyncio.create_task(_accountability_loop())
async def unload_accountability(client):
"""
Save the current sessions and cancel the runloop in preparation for client shutdown.
"""
...