320 lines
11 KiB
Python
320 lines
11 KiB
Python
from typing import Optional, overload, Literal
|
|
from enum import IntEnum
|
|
from collections import defaultdict
|
|
import datetime as dt
|
|
import asyncio
|
|
|
|
import discord
|
|
from cachetools import TTLCache
|
|
|
|
from utils.lib import utc_now
|
|
from meta import LionBot
|
|
from data import WeakCache
|
|
from .data import VoiceTrackerData
|
|
|
|
from . import logger
|
|
|
|
|
|
class TrackedVoiceState:
|
|
__slots__ = (
|
|
'channelid',
|
|
'video',
|
|
'stream'
|
|
)
|
|
|
|
def __init__(self, channelid: Optional[int], video: bool, stream: bool):
|
|
self.channelid = channelid
|
|
self.video = video
|
|
self.stream = stream
|
|
|
|
def __eq__(self, other: 'TrackedVoiceState'):
|
|
equal = other.channelid == self.channelid
|
|
equal = equal and other.video == self.video
|
|
equal = equal and other.stream == self.stream
|
|
|
|
def __bool__(self):
|
|
"""Whether this is an active state"""
|
|
return bool(self.channelid)
|
|
|
|
@property
|
|
def live(self):
|
|
return self.video or self.stream
|
|
|
|
@classmethod
|
|
def from_voice_state(cls, state: discord.VoiceState):
|
|
if state is not None:
|
|
return cls(
|
|
state.channel.id if state.channel else None,
|
|
state.self_video,
|
|
state.self_stream
|
|
)
|
|
else:
|
|
return cls(None, False, False)
|
|
|
|
|
|
class SessionState(IntEnum):
|
|
ONGOING = 2
|
|
PENDING = 1
|
|
INACTIVE = 0
|
|
|
|
|
|
class VoiceSession:
|
|
"""
|
|
High-level tracked voice state in the LionBot paradigm.
|
|
|
|
To ensure cache integrity and event safety,
|
|
this state may lag behind the `member.voice` obtained from Discord API.
|
|
However, the state must always match the stored state (in data).
|
|
"""
|
|
__slots__ = (
|
|
'bot',
|
|
'guildid', 'userid',
|
|
'registry',
|
|
'start_task', 'expiry_task',
|
|
'data', 'state', 'hourly_rate',
|
|
'_tag', '_start_time',
|
|
'lock',
|
|
'__weakref__'
|
|
)
|
|
|
|
_sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping
|
|
|
|
# Maintains strong references to active sessions
|
|
_active_sessions_: dict[int, dict[int, 'VoiceSession']] = defaultdict(dict)
|
|
|
|
def __init__(self, bot: LionBot, guildid: int, userid: int, data=None):
|
|
self.bot = bot
|
|
self.guildid = guildid
|
|
self.userid = userid
|
|
self.registry: VoiceTrackerData = self.bot.get_cog('VoiceTrackerCog').data
|
|
|
|
self.start_task = None # Task triggering a delayed session start
|
|
self.expiry_task = None # Task triggering a session expiry from reaching the daily cap
|
|
self.data: Optional[VoiceTrackerData.VoiceSessionsOngoing] = data # Ongoing session data
|
|
|
|
# TrackedVoiceState set when session is active
|
|
# Must match data when session in ongoing
|
|
self.state: Optional[TrackedVoiceState] = None
|
|
self.hourly_rate: Optional[float] = None
|
|
self._tag = None
|
|
self._start_time = None
|
|
|
|
# Member session lock
|
|
# Ensures state changes are atomic and serialised
|
|
self.lock = asyncio.Lock()
|
|
|
|
def cancel(self):
|
|
if self.start_task is not None:
|
|
self.start_task.cancel()
|
|
if self.expiry_task is not None:
|
|
self.expiry_task.cancel()
|
|
self._active_sessions_[self.guildid].pop(self.userid, None)
|
|
|
|
@property
|
|
def tag(self) -> Optional[str]:
|
|
if self.data:
|
|
tag = self.data.tag
|
|
else:
|
|
tag = self._tag
|
|
return tag
|
|
|
|
@property
|
|
def start_time(self):
|
|
if self.data:
|
|
start_time = self.data.start_time
|
|
else:
|
|
start_time = self._start_time
|
|
return start_time
|
|
|
|
@property
|
|
def activity(self):
|
|
if self.data is not None:
|
|
return SessionState.ONGOING
|
|
elif self.start_task is not None:
|
|
return SessionState.PENDING
|
|
else:
|
|
return SessionState.INACTIVE
|
|
|
|
@overload
|
|
@classmethod
|
|
def get(cls, bot: LionBot, guildid: int, userid: int, create: Literal[False]) -> Optional['VoiceSession']:
|
|
...
|
|
|
|
@overload
|
|
@classmethod
|
|
def get(cls, bot: LionBot, guildid: int, userid: int, create: Literal[True] = True) -> 'VoiceSession':
|
|
...
|
|
|
|
@classmethod
|
|
def get(cls, bot: LionBot, guildid: int, userid: int, create=True) -> Optional['VoiceSession']:
|
|
"""
|
|
Fetch the VoiceSession for the given member. Respects cache.
|
|
Creates the session if it doesn't already exist.
|
|
"""
|
|
session = cls._sessions_[guildid].get(userid, None)
|
|
if session is None and create:
|
|
session = cls(bot, guildid, userid)
|
|
cls._sessions_[guildid][userid] = session
|
|
return session
|
|
|
|
@classmethod
|
|
def from_ongoing(cls, bot: LionBot, data: VoiceTrackerData.VoiceSessionsOngoing, expires_at: dt.datetime):
|
|
"""
|
|
Create a VoiceSession from ongoing data and expiry time.
|
|
"""
|
|
self = cls.get(bot, data.guildid, data.userid)
|
|
if self.activity:
|
|
raise ValueError("Initialising a session which is already running!")
|
|
self.data = data
|
|
self.state = TrackedVoiceState(data.channelid, data.live_video, data.live_stream)
|
|
self.hourly_rate = data.hourly_coins
|
|
self.schedule_expiry(expires_at)
|
|
self._active_sessions_[self.guildid][self.userid] = self
|
|
return self
|
|
|
|
async def set_tag(self, new_tag):
|
|
async with self.lock:
|
|
if self.activity is SessionState.INACTIVE:
|
|
raise ValueError("Cannot set tag on an inactive voice session.")
|
|
self._tag = new_tag
|
|
if self.data is not None:
|
|
await self.data.update(tag=new_tag)
|
|
|
|
async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate):
|
|
"""
|
|
Schedule the voice session to start at the given target time,
|
|
with the given state and hourly rate.
|
|
"""
|
|
self.state = state
|
|
self.hourly_rate = hourly_rate
|
|
self._start_time = start_time
|
|
self._tag = None
|
|
|
|
self.start_task = asyncio.create_task(self._start_after(delay, start_time))
|
|
self.schedule_expiry(expire_time)
|
|
self._active_sessions_[self.guildid][self.userid] = self
|
|
|
|
async def _start_after(self, delay: int, start_time: dt.datetime):
|
|
"""
|
|
Start a new voice session with the given state and hourly rate.
|
|
|
|
Creates the tracked_channel if required.
|
|
"""
|
|
await asyncio.sleep(delay)
|
|
|
|
async with self.lock:
|
|
logger.info(
|
|
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
|
|
f"and channel <cid:{self.state.channelid}>."
|
|
)
|
|
# Create the lion if required
|
|
await self.bot.core.lions.fetch_member(self.guildid, self.userid)
|
|
|
|
# Create the tracked channel if required
|
|
await self.registry.TrackedChannel.fetch_or_create(
|
|
self.state.channelid, guildid=self.guildid, deleted=False
|
|
)
|
|
|
|
# Insert an ongoing_session with the correct state, set data
|
|
state = self.state
|
|
self.data = await self.registry.VoiceSessionsOngoing.create(
|
|
guildid=self.guildid,
|
|
userid=self.userid,
|
|
channelid=state.channelid,
|
|
start_time=start_time,
|
|
last_update=start_time,
|
|
live_stream=state.stream,
|
|
live_video=state.video,
|
|
hourly_coins=self.hourly_rate,
|
|
tag=self._tag
|
|
)
|
|
self.bot.dispatch('voice_session_start', self.data)
|
|
self.start_task = None
|
|
|
|
def schedule_expiry(self, expire_time):
|
|
"""
|
|
(Re-)schedule expiry for an ongoing session.
|
|
"""
|
|
if not self.activity:
|
|
raise ValueError("Cannot schedule expiry for an inactive session!")
|
|
if self.expiry_task is not None and not self.expiry_task.done():
|
|
self.expiry_task.cancel()
|
|
|
|
delay = (expire_time - utc_now()).total_seconds()
|
|
self.expiry_task = asyncio.create_task(self._expire_after(delay))
|
|
|
|
async def _expire_after(self, delay: int):
|
|
"""
|
|
Expire a session which has exceeded the daily voice cap.
|
|
"""
|
|
# TODO: Logging, and guild logging, and user notification (?)
|
|
await asyncio.sleep(delay)
|
|
logger.info(
|
|
f"Expiring voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
|
|
f"and channel <cid:{self.state.channelid}>."
|
|
)
|
|
await self.close()
|
|
|
|
async def update(self, new_state: Optional[TrackedVoiceState] = None, new_rate: Optional[int] = None):
|
|
"""
|
|
Update the session state with the provided voice state or hourly rate.
|
|
Also applies to pending states.
|
|
|
|
Raises ValueError if the state does not match the saved session (i.e. wrong channel)
|
|
"""
|
|
if not self.activity:
|
|
raise ValueError("Cannot update inactive session!")
|
|
elif (new_state is not None and new_state != self.state) or (new_rate != self.hourly_rate):
|
|
if new_state is not None:
|
|
self.state = new_state
|
|
if new_rate is not None:
|
|
self.hourly_rate = new_rate
|
|
|
|
if self.data:
|
|
await self.data.update_voice_session_at(
|
|
guildid=self.guildid,
|
|
userid=self.userid,
|
|
_at=utc_now(),
|
|
stream=self.state.stream,
|
|
video=self.state.video,
|
|
rate=self.hourly_rate
|
|
)
|
|
|
|
async def close(self):
|
|
"""
|
|
Close the session, or cancel the pending session. Idempotent.
|
|
"""
|
|
async with self.lock:
|
|
if self.activity is SessionState.ONGOING:
|
|
# End the ongoing session
|
|
now = utc_now()
|
|
await self.data.close_study_session_at(self.guildid, self.userid, now)
|
|
|
|
# TODO: Something a bit saner/safer.. dispatch the finished session instead?
|
|
self.bot.dispatch('voice_session_end', self.data, now)
|
|
|
|
# Rank update
|
|
# TODO: Change to broadcasted event?
|
|
rank_cog = self.bot.get_cog('RankCog')
|
|
if rank_cog is not None:
|
|
asyncio.create_task(rank_cog.on_voice_session_complete(
|
|
(self.guildid, self.userid, int((utc_now() - self.data.start_time).total_seconds()), 0)
|
|
))
|
|
|
|
if self.start_task is not None:
|
|
self.start_task.cancel()
|
|
self.start_task = None
|
|
|
|
if self.expiry_task is not None:
|
|
self.expiry_task.cancel()
|
|
self.expiry_task = None
|
|
|
|
self.data = None
|
|
self.state = None
|
|
self.hourly_rate = None
|
|
self._tag = None
|
|
self._start_time = None
|
|
|
|
# Always release strong reference to session (to allow garbage collection)
|
|
self._active_sessions_[self.guildid].pop(self.userid)
|