import asyncio
import logging
import traceback
import datetime

import discord

from meta import client
from settings import GuildSettings
from utils.lib import FieldEnum, strfdelta, utc_now

from .. import data
from ..module import module


class TicketType(FieldEnum):
    """
    The possible ticket types.
    """
    NOTE = 'NOTE', 'Note'
    WARNING = 'WARNING', 'Warning'
    STUDY_BAN = 'STUDY_BAN', 'Study Ban'
    # NOTE(review): the member name is misspelled ("MESAGE"); it is left as-is
    # since other modules may reference `TicketType.MESAGE_CENSOR`.
    MESAGE_CENSOR = 'MESSAGE_CENSOR', 'Message Censor'
    INVITE_CENSOR = 'INVITE_CENSOR', 'Invite Censor'


class TicketState(FieldEnum):
    """
    The possible ticket states.

    `OPEN` and `EXPIRING` are both displayed as "Active"; only `EXPIRING`
    tickets are scheduled for automatic expiry (see `Ticket.update_expiry`).
    """
    OPEN = 'OPEN', "Active"
    EXPIRING = 'EXPIRING', "Active"
    EXPIRED = 'EXPIRED', "Expired"
    PARDONED = 'PARDONED', "Pardoned"


class Ticket:
    """
    Abstract base class representing a Ticketed moderation action.

    Concrete ticket types subclass this, set `_ticket_type`, and are
    registered via the `Ticket.register_ticket_type` decorator so that
    `fetch_tickets` can instantiate the right subclass for each row.
    """
    # Type of event the class represents
    _ticket_type = None  # type: TicketType

    _ticket_types = {}  # Map: TicketType -> Ticket subclass
    _expiry_tasks = {}  # Map: ticketid -> expiry Task

    def __init__(self, ticketid, *args, **kwargs):
        self.ticketid = ticketid

    @classmethod
    async def create(cls, *args, **kwargs):
        """
        Create a new ticket of the current type.

        Implementations should add a row to the ticket table, post the
        ticket, and return the new Ticket.
        """
        raise NotImplementedError

    @property
    def data(self):
        """
        The `ticket_info` row for this ticket.

        Fetched on every access; presumably served from the data layer's
        row cache (see the note in `fetch_tickets`) — TODO confirm.
        """
        return data.ticket_info.fetch(self.ticketid)

    @property
    def guild(self):
        """The guild this ticket belongs to, or `None` if not cached."""
        return client.get_guild(self.data.guildid)

    @property
    def target(self):
        """The targeted member, or `None` if the guild or member is unavailable."""
        guild = self.guild
        return guild.get_member(self.data.targetid) if guild else None

    @property
    def msg_args(self):
        """
        Keyword arguments for posting this ticket in the moderation log.

        Returns a dict with an `embed` key and, when the ticket has an
        attached file, a `file` key, suitable for `channel.send(**args)`.
        """
        args = {}

        info = self.data
        # Normalise once so the comparisons below work for raw or enum values.
        state = TicketState(info.ticket_state)

        # Build embed
        member = self.target
        name = str(member) if member else str(info.targetid)

        if info.auto:
            title_fmt = "Ticket #{} | {} | {}[Auto] | {}"
        else:
            title_fmt = "Ticket #{} | {} | {} | {}"
        title = title_fmt.format(
            info.guild_ticketid,
            state.desc,
            TicketType(info.ticket_type).desc,
            name
        )

        embed = discord.Embed(
            title=title,
            description=info.content,
            timestamp=info.created_at
        )
        embed.add_field(
            name="Target",
            value="<@{}>".format(info.targetid)
        )

        if not info.auto:
            embed.add_field(
                name="Moderator",
                value="<@{}>".format(info.moderator_id)
            )

        if info.expiry:
            # Discord timestamp markup takes whole unix seconds; an empty
            # string here would be rejected as an empty embed field value.
            expiry_ts = "<t:{:.0f}>".format(info.expiry.timestamp())
            if state == TicketState.EXPIRING:
                value = expiry_ts
                if info.duration is not None:
                    value += "\n(Duration: `{}`)".format(
                        strfdelta(datetime.timedelta(seconds=info.duration))
                    )
                embed.add_field(
                    name="Expires at",
                    value=value
                )
            elif state == TicketState.EXPIRED:
                embed.add_field(
                    name="Expired",
                    value=expiry_ts
                )
            else:
                embed.add_field(
                    name="Expiry",
                    value=expiry_ts
                )

        if info.context:
            embed.add_field(
                name="Context",
                value=info.context,
                inline=False
            )

        if info.addendum:
            embed.add_field(
                name="Notes",
                value=info.addendum,
                inline=False
            )

        if state == TicketState.PARDONED:
            embed.add_field(
                name="Pardoned",
                value=(
                    "Pardoned by <@{}> at <t:{:.0f}>.\n{}"
                ).format(
                    info.pardoned_by,
                    info.pardoned_at.timestamp(),
                    info.pardoned_reason or ""
                ),
                inline=False
            )

        embed.set_footer(text="ID: {}".format(info.targetid))

        args['embed'] = embed

        # Add file
        # NOTE(review): `discord.File` expects a path or file-like object;
        # confirm what type `file_data` holds.
        if info.file_name:
            args['file'] = discord.File(info.file_data, info.file_name)

        return args

    @property
    def link(self):
        """
        Jump URL of the ticket's message in the moderation log.

        Built from the guild id, the mod-log setting's stored data
        (presumably the channel id — TODO confirm) and `log_msg_id`.
        """
        info = self.data
        modlog = GuildSettings(info.guildid).mod_log.data
        return 'https://discord.com/channels/{}/{}/{}'.format(
            info.guildid,
            modlog,
            info.log_msg_id
        )

    @property
    def state(self):
        """The current `TicketState` of this ticket."""
        return TicketState(self.data.ticket_state)

    async def update(self, **kwargs):
        """
        Update ticket fields, then refresh the expiry schedule and log message.

        Only the keys listed in `fields` are written; any other keyword
        arguments are ignored. The expiry task and moderation log message
        are refreshed even when no recognised field was given.
        """
        fields = (
            'targetid', 'moderator_id', 'auto', 'log_msg_id',
            'content', 'expiry', 'ticket_state',
            'context', 'addendum', 'duration', 'file_name', 'file_data',
            'pardoned_by', 'pardoned_at', 'pardoned_reason',
        )
        params = {field: kwargs[field] for field in fields if field in kwargs}
        if params:
            data.ticket_info.update_where(params, ticketid=self.ticketid)

        # `update_expiry` is synchronous; awaiting its `None` result would
        # raise TypeError and skip the repost.
        self.update_expiry()
        await self.post()

    async def post(self):
        """
        Post or update the ticket in the moderation log.

        Edits the saved log message when it was sent by this client.
        Otherwise (no saved message, message deleted, or authored by someone
        else) sends a fresh message and saves its id to `log_msg_id`.
        Does nothing when the guild has no moderation log configured.
        Errors are logged rather than raised.
        """
        info = self.data
        modlog = GuildSettings(info.guildid).mod_log.value
        if not modlog:
            return

        try:
            args = self.msg_args

            message = None
            if info.log_msg_id:
                try:
                    message = await modlog.fetch_message(info.log_msg_id)
                except discord.NotFound:
                    # The previous log message was deleted; send a new one.
                    message = None

            if message is not None and message.author.id == client.user.id:
                # TODO: Handle file edit
                await message.edit(embed=args['embed'])
            else:
                if message is not None:
                    # We cannot edit someone else's message, so replace it.
                    try:
                        await message.delete()
                    except discord.HTTPException:
                        pass
                message = await modlog.send(**args)
                self.data.log_msg_id = message.id
        except discord.HTTPException:
            client.log(
                "Cannot post ticket (tid: {}) due to discord exception or issue.".format(self.ticketid),
                context="TICKETS"
            )
        except Exception:
            # This should never happen in normal operation
            client.log(
                "Error while posting ticket (tid:{})! "
                "Exception traceback follows.\n{}".format(
                    self.ticketid,
                    traceback.format_exc()
                ),
                context="TICKETS",
                level=logging.ERROR
            )

    @classmethod
    def load_expiring(cls):
        """
        Load and schedule all expiring tickets.

        Cancels any previously scheduled expiry tasks, then creates one
        expiry task per ticket currently in the `EXPIRING` state.
        """
        # TODO: Consider changing this to a flat timestamp system, to avoid storing lots of coroutines.
        # TODO: Consider only scheduling the expiries in the next day, and updating this once per day.
        # TODO: Only fetch tickets from guilds we are in.

        # Cancel existing expiry tasks
        for task in cls._expiry_tasks.values():
            if not task.done():
                task.cancel()

        # Get all expiring tickets
        expiring_rows = data.tickets.select_where(
            ticket_state=TicketState.EXPIRING
        )

        # Create new expiry tasks
        now = utc_now()
        cls._expiry_tasks = {
            row['ticketid']: asyncio.create_task(
                cls._schedule_expiry_for(
                    row['ticketid'],
                    (row['expiry'] - now).total_seconds()
                )
            ) for row in expiring_rows
        }

        # Log
        client.log(
            "Loaded {} expiring tickets.".format(len(cls._expiry_tasks)),
            context="TICKET_LOADER",
        )

    @classmethod
    async def _schedule_expiry_for(cls, ticketid, delay):
        """
        Wait `delay` seconds, then expire the given ticket if it still exists.

        A negative delay (expiry already past) does not wait. Cancellation
        while waiting exits quietly.
        """
        try:
            await asyncio.sleep(delay)
            ticket = Ticket.fetch(ticketid)
            if ticket:
                # Shielded so cancelling this scheduler task (e.g. from
                # `update_expiry`) does not interrupt an expiry in progress.
                await asyncio.shield(ticket._expire())
        except asyncio.CancelledError:
            return

    def update_expiry(self):
        """
        Synchronise this ticket's expiry task with its current state.

        Cancels any existing task, then schedules a new one only if the
        ticket is `EXPIRING`.
        """
        # Cancel any existing expiry task
        task = self._expiry_tasks.pop(self.ticketid, None)
        if task and not task.done():
            task.cancel()

        # Schedule a new expiry task, if applicable
        if self.data.ticket_state == TicketState.EXPIRING:
            self._expiry_tasks[self.ticketid] = asyncio.create_task(
                self._schedule_expiry_for(
                    self.ticketid,
                    (self.data.expiry - utc_now()).total_seconds()
                )
            )

    async def cancel_expiry(self):
        """
        Cancel ticket expiry.

        In particular, may be used if another ticket overrides `self`.
        Sets the ticket state to `OPEN`, so that it no longer expires.
        Has no effect unless the ticket is currently `EXPIRING`.
        """
        if self.state == TicketState.EXPIRING:
            # Update the ticket state
            self.data.ticket_state = TicketState.OPEN

            # Remove from expiry tasks
            self.update_expiry()

            # Repost
            await self.post()

    async def _revert(self, reason=None):
        """
        Revert the ticket action, e.g. unban or remove mute role.

        Generally called by `pardon` and `_expire`.
        Must be overridden by the Ticket type, if they implement any revert logic.
        """
        raise NotImplementedError

    async def _revert_logged(self, reason, action):
        """
        Call `_revert`, logging (not propagating) any failure.

        `action` is the verb used in the error log (e.g. "expiring").
        Ticket types without revert logic leave `_revert` unimplemented,
        which is treated as "nothing to revert" rather than an error.
        """
        try:
            await self._revert(reason=reason)
        except NotImplementedError:
            # This ticket type has no revert action.
            pass
        except Exception:
            # This should never happen in normal operation
            client.log(
                "Error while {} ticket (tid:{})! "
                "Exception traceback follows.\n{}".format(
                    action,
                    self.ticketid,
                    traceback.format_exc()
                ),
                context="TICKETS",
                level=logging.ERROR
            )

    async def _expire(self):
        """
        Automatically expire a ticket.

        May be overridden by the Ticket type for more complex expiry logic.
        Must set `data.ticket_state` to `EXPIRED` if applicable.
        Only acts on tickets in the `EXPIRING` state.
        """
        if self.state != TicketState.EXPIRING:
            return

        client.log(
            "Automatically expiring ticket (tid:{}).".format(self.ticketid),
            context="TICKETS"
        )
        await self._revert_logged("Automatic Expiry", "expiring")

        # Update state
        self.data.ticket_state = TicketState.EXPIRED

        # Update log message
        await self.post()

        # Post a note to the modlog
        modlog = GuildSettings(self.data.guildid).mod_log.value
        if modlog:
            try:
                await modlog.send(
                    embed=discord.Embed(
                        colour=discord.Colour.orange(),
                        description="[Ticket #{}]({}) expired!".format(self.data.guild_ticketid, self.link)
                    )
                )
            except discord.HTTPException:
                pass

    async def pardon(self, moderator, reason, timestamp=None):
        """
        Pardon process for the ticket.

        May be overridden by the Ticket type for more complex pardon logic.
        Must set `data.ticket_state` to `PARDONED` if applicable.
        Active (`OPEN`/`EXPIRING`) tickets are reverted first; already
        pardoned tickets are left unchanged.
        `timestamp` is currently unused; the pardon time is `utc_now()`.
        """
        if self.state == TicketState.PARDONED:
            return

        if self.state in (TicketState.OPEN, TicketState.EXPIRING):
            await self._revert_logged(
                "Pardoned by {}".format(moderator.id),
                "pardoning"
            )

        # Update state
        with self.data.batch_update():
            self.data.ticket_state = TicketState.PARDONED
            self.data.pardoned_at = utc_now()
            self.data.pardoned_by = moderator.id
            self.data.pardoned_reason = reason

        # Update (i.e. remove) expiry
        self.update_expiry()

        # Update log message
        await self.post()

    @classmethod
    def fetch_tickets(cls, *ticketids, **kwargs):
        """
        Fetch tickets matching the given criteria (passed transparently to `select_where`).

        Positional arguments are treated as `ticketids`, which are not supported in keyword arguments.
        When called on a registered subclass, results are restricted to
        that subclass's ticket type unless `ticket_type` is given.
        """
        if ticketids:
            kwargs['ticketid'] = ticketids

        # Set the ticket type to the class type if not specified
        if cls._ticket_type and 'ticket_type' not in kwargs:
            kwargs['ticket_type'] = cls._ticket_type

        # This is actually mainly for caching, since we don't pass the data to the initialiser
        rows = data.ticket_info.fetch_rows_where(
            **kwargs
        )

        return [
            cls._ticket_types[TicketType(row.ticket_type)](row.ticketid)
            for row in rows
        ]

    @classmethod
    def fetch(cls, ticketid):
        """
        Return the Ticket with the given id, if found, or `None` otherwise.
        """
        tickets = cls.fetch_tickets(ticketid)
        return tickets[0] if tickets else None

    @classmethod
    def register_ticket_type(cls, ticket_cls):
        """
        Decorator to register a new Ticket subclass as a ticket type.
        """
        cls._ticket_types[ticket_cls._ticket_type] = ticket_cls
        return ticket_cls


@module.launch_task
async def load_expiring_tickets(client):
    """Launch task: schedule expiry for all currently expiring tickets."""
    Ticket.load_expiring()