rewrite: Restructure to include GUI.

This commit is contained in:
2022-12-23 06:44:32 +02:00
parent 2b93354248
commit f328324747
224 changed files with 8 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
from .cog import Analytics
async def setup(bot):
await bot.add_cog(Analytics(bot))

177
src/analytics/cog.py Normal file
View File

@@ -0,0 +1,177 @@
import logging
import discord
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError
from meta import LionCog, LionBot, LionContext
from meta.app import shard_talk, appname
from meta.errors import HandledException, SafeCancellation
from meta.logger import log_wrap
from utils.lib import utc_now
from .data import AnalyticsData
from .events import (
CommandStatus, CommandEvent, command_event_handler,
GuildAction, GuildEvent, guild_event_handler,
VoiceAction, VoiceEvent, voice_event_handler
)
from .snapshot import shard_snapshot
logger = logging.getLogger(__name__)
# TODO: Client side might be better handled as a single connection fed by a queue?
# Maybe consider this again after the interactive REPL idea
# Or if it seems like this is giving an absurd amount of traffic
class Analytics(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(AnalyticsData())
self.an_app = bot.config.analytics['appname']
self.talk_command_event = command_event_handler.bind(shard_talk).route
self.talk_guild_event = guild_event_handler.bind(shard_talk).route
self.talk_voice_event = voice_event_handler.bind(shard_talk).route
self.talk_shard_snapshot = shard_talk.register_route()(shard_snapshot)
async def cog_load(self):
await self.data.init()
@LionCog.listener()
@log_wrap(action='AnEvent')
async def on_voice_state_update(self, member, before, after):
if not before.channel and after.channel:
# Member joined channel
action = VoiceAction.JOINED
elif before.channel and not after.channel:
# Member left channel
action = VoiceAction.LEFT
else:
# Member change state, we don't need to deal with that
return
event = VoiceEvent(
appname=appname,
userid=member.id,
guildid=member.guild.id,
action=action,
created_at=utc_now()
)
if self.an_app not in shard_talk.peers:
logger.warning(f"Analytics peer not found, discarding event: {event}")
else:
await self.talk_voice_event(event).send(self.an_app, wait_for_reply=False)
@LionCog.listener()
@log_wrap(action='AnEvent')
async def on_guild_join(self, guild):
"""
Send guild join event.
"""
event = GuildEvent(
appname=appname,
guildid=guild.id,
action=GuildAction.JOINED,
created_at=utc_now()
)
if self.an_app not in shard_talk.peers:
logger.warning(f"Analytics peer not found, discarding event: {event}")
else:
await self.talk_guild_event(event).send(self.an_app, wait_for_reply=False)
@LionCog.listener()
@log_wrap(action='AnEvent')
async def on_guild_remove(self, guild):
"""
Send guild leave event
"""
event = GuildEvent(
appname=appname,
guildid=guild.id,
action=GuildAction.LEFT,
created_at=utc_now()
)
if self.an_app not in shard_talk.peers:
logger.warning(f"Analytics peer not found, discarding event: {event}")
else:
await self.talk_guild_event(event).send(self.an_app, wait_for_reply=False)
@LionCog.listener()
@log_wrap(action='AnEvent')
async def on_command_completion(self, ctx: LionContext):
"""
Send command completed successfully.
"""
duration = utc_now() - ctx.message.created_at
event = CommandEvent(
appname=appname,
cmdname=ctx.command.name if ctx.command else 'Unknown',
cogname=ctx.cog.qualified_name if ctx.cog else None,
userid=ctx.author.id,
created_at=utc_now(),
status=CommandStatus.COMPLETED,
execution_time=duration.total_seconds(),
guildid=ctx.guild.id if ctx.guild else None,
ctxid=ctx.message.id
)
if self.an_app not in shard_talk.peers:
logger.warning(f"Analytics peer not found, discarding event: {event}")
else:
await self.talk_command_event(event).send(self.an_app, wait_for_reply=False)
@LionCog.listener()
@log_wrap(action='AnEvent')
async def on_command_error(self, ctx: LionContext, error):
"""
Send command failed.
"""
duration = utc_now() - ctx.message.created_at
status = CommandStatus.FAILED
err_type = None
try:
err_type = repr(error)
raise error
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
original = error.original
try:
err_type = repr(original)
if isinstance(original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
raise original.original
else:
raise original
except HandledException:
status = CommandStatus.CANCELLED
except SafeCancellation:
status = CommandStatus.CANCELLED
except discord.Forbidden:
status = CommandStatus.CANCELLED
except discord.HTTPException:
status = CommandStatus.CANCELLED
except Exception:
status = CommandStatus.FAILED
except CheckFailure:
status = CommandStatus.CANCELLED
except Exception:
status = CommandStatus.FAILED
event = CommandEvent(
appname=appname,
cmdname=ctx.command.name if ctx.command else 'Unknown',
cogname=ctx.cog.qualified_name if ctx.cog else None,
userid=ctx.author.id,
created_at=utc_now(),
status=status,
error=err_type,
execution_time=duration.total_seconds(),
guildid=ctx.guild.id if ctx.guild else None,
ctxid=ctx.message.id
)
if self.an_app not in shard_talk.peers:
logger.warning(f"Analytics peer not found, discarding event: {event}")
else:
await self.talk_command_event(event).send(self.an_app, wait_for_reply=False)

189
src/analytics/data.py Normal file
View File

@@ -0,0 +1,189 @@
from enum import Enum
from data.registry import Registry
from data.adapted import RegisterEnum
from data.models import RowModel
from data.columns import Integer, String, Timestamp, Column
class CommandStatus(Enum):
"""
Schema
------
CREATE TYPE analytics.CommandStatus AS ENUM(
'COMPLETED',
'CANCELLED'
'FAILED'
);
"""
COMPLETED = ('COMPLETED',)
CANCELLED = ('CANCELLED',)
FAILED = ('FAILED',)
class GuildAction(Enum):
"""
Schema
------
CREATE TYPE analytics.GuildAction AS ENUM(
'JOINED',
'LEFT'
);
"""
JOINED = ('JOINED',)
LEFT = ('LEFT',)
class VoiceAction(Enum):
"""
Schema
------
CREATE TYPE analytics.VoiceAction AS ENUM(
'JOINED',
'LEFT'
);
"""
JOINED = ('JOINED',)
LEFT = ('LEFT',)
class AnalyticsData(Registry, name='analytics'):
CommandStatus = RegisterEnum(CommandStatus, name="analytics.CommandStatus")
GuildAction = RegisterEnum(GuildAction, name="analytics.GuildAction")
VoiceAction = RegisterEnum(VoiceAction, name="analytics.VoiceAction")
class Snapshots(RowModel):
"""
Schema
------
CREATE TABLE analytics.snapshots(
snapshotid SERIAL PRIMARY KEY,
appname TEXT NOT NULL REFERENCES bot_config (appname),
guild_count INTEGER NOT NULL,
member_count INTEGER NOT NULL,
user_count INTEGER NOT NULL,
in_voice INTEGER NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT (now() at time zone 'utc')
);
"""
_schema_ = 'analytics'
_tablename_ = 'snapshots'
snapshotid = Integer(primary=True)
appname = String()
guild_count = Integer()
member_count = Integer()
user_count = Integer()
in_voice = Integer()
created_at = Timestamp()
class Events(RowModel):
"""
Schema
------
CREATE TABLE analytics.events(
eventid SERIAL PRIMARY KEY,
appname TEXT NOT NULL REFERENCES bot_config (appname),
ctxid BIGINT,
guildid BIGINT,
_created_at TIMESTAMPTZ NOT NULL DEFAULT (now() at time zone 'utc')
);
"""
_schema_ = 'analytics'
_tablename_ = 'events'
eventid = Integer(primary=True)
appname = String()
ctxid = Integer()
guildid = Integer()
created_at = Timestamp()
class Commands(RowModel):
"""
Schema
------
CREATE TABLE analytics.commands(
cmdname TEXT NOT NULL,
cogname TEXT,
userid BIGINT NOT NULL,
status analytics.CommandStatus NOT NULL,
execution_time REAL NOT NULL
) INHERITS (analytics.events);
"""
_schema_ = 'analytics'
_tablename_ = 'commands'
eventid = Integer(primary=True)
appname = String()
ctxid = Integer()
guildid = Integer()
created_at = Timestamp()
cmdname = String()
cogname = String()
userid = Integer()
status: Column[CommandStatus] = Column()
error = String()
execution_time: Column[float] = Column()
class Guilds(RowModel):
"""
Schema
------
CREATE TABLE analytics.guilds(
guildid BIGINT NOT NULL,
action analytics.GuildAction NOT NULL
) INHERITS (analytics.events);
"""
_schema_ = 'analytics'
_tablename_ = 'guilds'
eventid = Integer(primary=True)
appname = String()
ctxid = Integer()
guildid = Integer()
created_at = Timestamp()
action: Column[GuildAction] = Column()
class VoiceSession(RowModel):
"""
Schema
------
CREATE TABLE analytics.voice_sessions(
userid BIGINT NOT NULL,
action analytics.VoiceAction NOT NULL
) INHERITS (analytics.events);
"""
_schema_ = 'analytics'
_tablename_ = 'voice_sessions'
eventid = Integer(primary=True)
appname = String()
ctxid = Integer()
guildid = Integer()
created_at = Timestamp()
userid = Integer()
action: Column[GuildAction] = Column()
class GuiRender(RowModel):
"""
Schema
------
CREATE TABLE analytics.gui_renders(
cardname TEXT NOT NULL,
duration INTEGER NOT NULL
) INHERITS (analytics.events);
"""
_schema_ = 'analytics'
_tablename_ = 'gui_renders'
eventid = Integer(primary=True)
appname = String()
ctxid = Integer()
guildid = Integer()
created_at = Timestamp()
cardname = String()
duration = Integer()

180
src/analytics/events.py Normal file
View File

@@ -0,0 +1,180 @@
import asyncio
import datetime
import logging
from collections import namedtuple
from typing import NamedTuple, Optional, Generic, Type, TypeVar
from meta.ipc import AppRoute, AppClient
from meta.logger import logging_context, log_wrap
from data import RowModel
from .data import AnalyticsData, CommandStatus, VoiceAction, GuildAction
logger = logging.getLogger(__name__)
"""
TODO
Snapshot type? Incremental or manual?
Request snapshot route will require all shards to be online
Update batch size before release, or put it in the config
"""
T = TypeVar('T')
class EventHandler(Generic[T]):
def __init__(self, route_name: str, model: Type[RowModel], struct: Type[T], batchsize: int = 20):
self.model = model
self.struct = struct
self.batch_size = batchsize
self.route_name = route_name
self._route: Optional[AppRoute] = None
self._client: Optional[AppClient] = None
self.queue: asyncio.Queue[T] = asyncio.Queue()
self.batch: list[T] = []
self._consumer_task: Optional[asyncio.Task] = None
@property
def route(self):
if self._route is None:
self._route = AppRoute(self.handle_event, name=self.route_name)
return self._route
async def handle_event(self, data):
try:
await self.queue.put(data)
except asyncio.QueueFull:
logger.warning(
f"Queue on event handler {self.route_name} is full! Discarding event {data}"
)
async def consumer(self):
with logging_context(action='consumer'):
while True:
try:
item = await self.queue.get()
self.batch.append(item)
if len(self.batch) > self.batch_size:
await self.process_batch()
except asyncio.CancelledError:
# Try and process the last batch
logger.info(
f"Event handler {self.route_name} received cancellation signal! "
"Trying to process last batch."
)
if self.batch:
await self.process_batch()
raise
except Exception:
logger.exception(
f"Event handler {self.route_name} received unhandled error."
" Ignoring and continuing cautiously."
)
pass
async def process_batch(self):
with logging_context(action='batch'):
logger.debug("Processing Batch")
# TODO: copy syntax might be more efficient here
await self.model.table.insert_many(
self.struct._fields,
*map(tuple, self.batch)
)
self.batch.clear()
def bind(self, client: AppClient):
"""
Bind our route to the given client.
"""
if self._client:
raise ValueError("This EventHandler is already attached!")
self._client = client
self.route._client = client
client.routes[self.route_name] = self.route
return self
def unbind(self):
"""
Unbind from the client.
"""
if not self._client:
raise ValueError("Not attached, cannot detach!")
self._client.routes.pop(self.route_name, None)
self._route = None
logger.info(
f"EventHandler {self.route_name} has attached to the ShardTalk client."
)
return self
async def attach(self, client: AppClient):
"""
Attach to a ShardTalk client and start listening.
"""
with logging_context(action=self.route_name):
self.bind(client)
self._consumer_task = asyncio.create_task(self.consumer())
logger.info(
f"EventHandler {self.route_name} is listening for incoming events."
)
return self
async def detach(self):
"""
Stop listening and detach from client.
"""
self.unbind()
if self._consumer_task and not self._consumer_task.done():
self._consumer_task.cancel()
self._consumer_task = None
logger.info(
f"EventHandler {self.route_name} has detached."
)
return self
class CommandEvent(NamedTuple):
appname: str
cmdname: str
userid: int
created_at: datetime.datetime
status: CommandStatus
execution_time: float
error: Optional[str] = None
cogname: Optional[str] = None
guildid: Optional[int] = None
ctxid: Optional[int] = None
command_event_handler: EventHandler[CommandEvent] = EventHandler(
'command_event', AnalyticsData.Commands, CommandEvent, batchsize=1
)
class GuildEvent(NamedTuple):
appname: str
guildid: int
action: GuildAction
created_at: datetime.datetime
guild_event_handler: EventHandler[GuildEvent] = EventHandler(
'guild_event', AnalyticsData.Guilds, GuildEvent, batchsize=0
)
class VoiceEvent(NamedTuple):
appname: str
guildid: int
userid: int
action: VoiceAction
created_at: datetime.datetime
voice_event_handler: EventHandler[VoiceEvent] = EventHandler(
'voice_event', AnalyticsData.VoiceSession, VoiceEvent, batchsize=5
)

128
src/analytics/server.py Normal file
View File

@@ -0,0 +1,128 @@
import asyncio
import logging
from typing import Optional
from meta import conf, appname
from meta.logger import log_context, log_action_stack, logging_context, log_app, log_wrap
from meta.ipc import AppClient
from meta.app import appname_from_shard
from meta.sharding import shard_count
from data import Database
from .events import command_event_handler, guild_event_handler, voice_event_handler
from .snapshot import shard_snapshot, ShardSnapshot
from .data import AnalyticsData
logger = logging.getLogger(__name__)
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
logging.getLogger(name).setLevel(conf.logging_levels[name])
class AnalyticsServer:
# TODO: Move these to the config
# How often to request snapshots
snap_period = 120
# How soon after a snapshot failure (e.g. not all shards online) to retry
snap_retry_period = 10
def __init__(self) -> None:
self.db = Database(conf.data['args'])
self.data = self.db.load_registry(AnalyticsData())
self.event_handlers = [
command_event_handler,
guild_event_handler,
voice_event_handler
]
self.talk = AppClient(
conf.analytics['appname'],
appname,
{'host': conf.analytics['server_host'], 'port': int(conf.analytics['server_port'])},
{'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
)
self.talk_shard_snapshot = self.talk.register_route()(shard_snapshot)
self._snap_task: Optional[asyncio.Task] = None
async def attach_event_handlers(self):
for handler in self.event_handlers:
await handler.attach(self.talk)
@log_wrap(action='Snap')
async def take_snapshot(self):
# Check if all the shards are registered on shard_talk
expected_peers = [appname_from_shard(i) for i in range(0, shard_count)]
if missing := [peer for peer in expected_peers if peer not in self.talk.peers]:
# We are missing peer(s)!
logger.warning(
f"Analytics could not take snapshot because peers are missing: {', '.join(missing)}"
)
return False
# Everyone is here, ask for shard snapshots
results = await self.talk_shard_snapshot().broadcast()
# Make sure everyone sent results and there were no exceptions (e.g. concurrency)
if not all(result is not None and not isinstance(result, Exception) for result in results.values()):
# This should essentially never happen
# Either some of the shards could not make a snapshot (e.g. Discord client issues)
# or they disconnected in the process.
logger.warning(
f"Analytics could not take snapshot because some peers failed! Partial snapshot: {results}"
)
return False
# Now we have a dictionary of shard snapshots, aggregate, pull in remaining data, and store.
# TODO Possibly move this out into snapshots.py?
aggregate = {field: 0 for field in ShardSnapshot._fields}
for result in results.values():
for field, num in result._asdict().items():
aggregate[field] += num
row = await self.data.Snapshots.create(
appname=appname,
guild_count=aggregate['guild_count'],
member_count=aggregate['member_count'],
user_count=aggregate['user_count'],
in_voice=aggregate['voice_count'],
)
logger.info(f"Created snapshot: {row.data!r}")
return True
@log_wrap(action='SnapLoop')
async def snapshot_loop(self):
while True:
try:
result = await self.take_snapshot()
if result:
await asyncio.sleep(self.snap_period)
else:
logger.info("Snapshot failed, retrying after %d seconds", self.snap_retry_period)
await asyncio.sleep(self.snap_retry_period)
except asyncio.CancelledError:
logger.info("Snapshot loop cancelled, closing.")
return
except Exception:
logger.exception(
"Unhandled exception during snapshot loop. Ignoring and continuing cautiously."
)
await asyncio.sleep(self.snap_retry_period)
async def run(self):
log_action_stack.set(['Analytics'])
log_app.set(conf.analytics['appname'])
async with await self.db.connect():
await self.talk.connect()
await self.attach_event_handlers()
self._snap_task = asyncio.create_task(self.snapshot_loop())
await asyncio.gather(*(handler._consumer_task for handler in self.event_handlers))
if __name__ == '__main__':
server = AnalyticsServer()
asyncio.run(server.run())

28
src/analytics/snapshot.py Normal file
View File

@@ -0,0 +1,28 @@
from typing import NamedTuple
from meta.context import ctx_bot
class ShardSnapshot(NamedTuple):
guild_count: int
voice_count: int
member_count: int
user_count: int
async def shard_snapshot():
"""
Take a snapshot of the current shard.
"""
bot = ctx_bot.get()
if bot is None or not bot.is_ready():
# We cannot take a snapshot without Bot
# Just quietly fail
return None
snap = ShardSnapshot(
guild_count=len(bot.guilds),
voice_count=sum(len(channel.members) for guild in bot.guilds for channel in guild.voice_channels),
member_count=sum(len(guild.members) for guild in bot.guilds),
user_count=len(set(m.id for guild in bot.guilds for m in guild.members))
)
return snap

6
src/babel/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
from .translator import SOURCE_LOCALE, LeoBabel, LocalBabel, LazyStr, ctx_locale, ctx_translator
async def setup(bot):
from .cog import BabelCog
await bot.add_cog(BabelCog(bot))

300
src/babel/cog.py Normal file
View File

@@ -0,0 +1,300 @@
"""
Babel Cog.
Calculates and sets current locale before command runs (via check_once).
Also defines the relevant guild and user settings for localisation.
"""
from typing import Optional
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionCog, LionContext
from meta.errors import UserInputError
from settings import ModelData
from settings.setting_types import StringSetting, BoolSetting
from settings.groups import SettingGroup
from core.data import CoreData
from .translator import ctx_locale, ctx_translator, LocalBabel, SOURCE_LOCALE
babel = LocalBabel('babel')
_ = babel._
_p = babel._p
class LocaleSettings(SettingGroup):
class UserLocale(ModelData, StringSetting):
"""
User-configured locale.
Exposed via dedicated setting command.
"""
setting_id = 'user_locale'
display_name = _p('userset:locale', 'language')
desc = _p('userset:locale|desc', "Your preferred language for interacting with me.")
_model = CoreData.User
_column = CoreData.User.locale.name
@property
def update_message(self):
t = ctx_translator.get().t
if self.data is None:
return t(_p('userset:locale|response', "You have unset your language."))
else:
return t(_p('userset:locale|response', "You have set your language to `{lang}`.")).format(
lang=self.data
)
@classmethod
async def _parse_string(cls, parent_id, string, **kwargs):
translator = ctx_translator.get()
if string not in translator.supported_locales:
lang = string[:20]
raise UserInputError(
translator.t(
_p('userset:locale|error', "Sorry, we do not support the `{lang}` language at this time!")
).format(lang=lang)
)
return string
class ForceLocale(ModelData, BoolSetting):
"""
Guild configuration for whether to force usage of the guild locale.
Exposed via `/configure language` command and standard configuration interface.
"""
setting_id = 'force_locale'
display_name = _p('guildset:force_locale', 'force_language')
desc = _p('guildset:force_locale|desc',
"Whether to force all members to use the configured guild language when interacting with me.")
long_desc = _p(
'guildset:force_locale|long_desc',
"When enabled, commands in this guild will always use the configured guild language, "
"regardless of the member's personally configured language."
)
_outputs = {
True: _p('guildset:force_locale|output', 'Enabled (members will be forced to use the server language)'),
False: _p('guildset:force_locale|output', 'Disabled (members may set their own language)'),
None: 'Not Set' # This should be impossible, since we have a default
}
_default = False
_model = CoreData.Guild
_column = CoreData.Guild.force_locale.name
@property
def update_message(self):
t = ctx_translator.get().t
if self.data:
return t(_p(
'guildset:force_locale|response',
"I will always use the set language in this server."
))
else:
return t(_p(
'guildset:force_locale|response',
"I will now allow the members to set their own language here."
))
class GuildLocale(ModelData, StringSetting):
"""
Guild-configured locale.
Exposed via `/configure language` command, and standard configuration interface.
"""
setting_id = 'guild_locale'
display_name = _p('guildset:locale', 'language')
desc = _p('guildset:locale|desc', "Your preferred language for interacting with me.")
_model = CoreData.Guild
_column = CoreData.Guild.locale.name
@property
def update_message(self):
t = ctx_translator.get().t
if self.data is None:
return t(_p('guildset:locale|response', "You have reset the guild language."))
else:
return t(_p('guildset:locale|response', "You have set the guild language to `{lang}`.")).format(
lang=self.data
)
@classmethod
async def _parse_string(cls, parent_id, string, **kwargs):
translator = ctx_translator.get()
if string not in translator.supported_locales:
lang = string[:20]
raise UserInputError(
translator.t(
_p('guildset:locale|error', "Sorry, we do not support the `{lang}` language at this time!")
).format(lang=lang)
)
return string
class BabelCog(LionCog):
depends = {'CoreCog'}
def __init__(self, bot: LionBot):
self.bot = bot
self.settings = LocaleSettings()
self.t = self.bot.translator.t
async def cog_load(self):
if not self.bot.core:
raise ValueError("CoreCog must be loaded first!")
self.bot.core.guild_settings.attach(LocaleSettings.ForceLocale)
self.bot.core.guild_settings.attach(LocaleSettings.GuildLocale)
self.bot.core.user_settings.attach(LocaleSettings.UserLocale)
async def cog_unload(self):
pass
async def get_user_locale(self, userid):
"""
Fetch the best locale we can guess for this userid.
"""
data = await self.bot.core.data.User.fetch(userid)
if data:
return data.locale or data.locale_hint or SOURCE_LOCALE
else:
return SOURCE_LOCALE
async def bot_check_once(self, ctx: LionContext): # type: ignore # Type checker doesn't understand coro checks
"""
Calculate and inject the current locale before the command begins.
Locale resolution is calculated as follows:
If the guild has force_locale enabled, and a locale set,
then the guild's locale will be used.
Otherwise, the priority is
user_locale -> command_locale -> user_locale_hint -> guild_locale -> default_locale
"""
locale = None
if ctx.guild:
forced = ctx.alion.guild_settings['force_locale'].value
guild_locale = ctx.alion.guild_settings['guild_locale'].value
if forced:
locale = guild_locale
locale = locale or ctx.alion.user_settings['user_locale'].value
if ctx.interaction:
locale = locale or ctx.interaction.locale.value
if ctx.guild:
locale = locale or guild_locale
locale = locale or SOURCE_LOCALE
ctx_locale.set(locale)
ctx_translator.set(self.bot.translator)
return True
@cmds.hybrid_command(
name=LocaleSettings.UserLocale.display_name,
description=LocaleSettings.UserLocale.desc
)
async def cmd_language(self, ctx: LionContext, language: str):
"""
Dedicated user setting command for the `locale` setting.
"""
if not ctx.interaction:
# This command is not available as a text command
return
setting = await self.settings.UserLocale.get(ctx.author.id)
new_data = await setting._parse_string(ctx.author.id, language)
await setting.interactive_set(new_data, ctx.interaction)
@cmds.hybrid_command(
name=_p('cmd:configure_language', "configure_language"),
description=_p('cmd:configure_language|desc',
"Configure the default language I will use in this server.")
)
@appcmds.choices(
force_language=[
appcmds.Choice(name=LocaleSettings.ForceLocale._outputs[True], value=1),
appcmds.Choice(name=LocaleSettings.ForceLocale._outputs[False], value=0),
]
)
@appcmds.guild_only() # Can be removed when attached as a subcommand
async def cmd_configure_language(
self, ctx: LionContext, language: Optional[str] = None, force_language: Optional[appcmds.Choice[int]] = None
):
if not ctx.interaction:
# This command is not available as a text command
return
if not ctx.guild:
# This is impossible by decorators, but adding this guard for the type checker
return
t = self.t
# TODO: Setting group, and group setting widget
# We can attach the command to the setting group as an application command
# Then load it into the configure command group dynamically
lang_setting = await self.settings.GuildLocale.get(ctx.guild.id)
force_setting = await self.settings.ForceLocale.get(ctx.guild.id)
if language:
lang_data = await lang_setting._parse_string(ctx.guild.id, language)
if force_language is not None:
force_data = bool(force_language)
if force_language is not None and not (lang_data if language is not None else lang_setting.value):
# Setting force without having a language!
raise UserInputError(
t(_p(
'cmd:configure_language|error',
"You cannot enable `{force_setting}` without having a configured language!"
)).format(force_setting=t(LocaleSettings.ForceLocale.display_name))
)
# TODO: Really need simultaneous model writes, or batched writes
lines = []
if language:
lang_setting.data = lang_data
await lang_setting.write()
lines.append(lang_setting.update_message)
if force_language is not None:
force_setting.data = force_data
await force_setting.write()
lines.append(force_setting.update_message)
result = '\n'.join(
f"{self.bot.config.emojis.tick} {line}" for line in lines
)
# TODO: Setting group widget
await ctx.reply(
embed=discord.Embed(
colour=discord.Colour.green(),
title=t(_p('cmd:configure_language|success', "Language settings updated!")),
description=result
)
)
@cmd_configure_language.autocomplete('language')
async def cmd_configure_language_acmpl_language(self, interaction: discord.Interaction, partial: str):
# TODO: More friendly language names
supported = self.bot.translator.supported_locales
matching = [lang for lang in supported if partial.lower() in lang]
t = self.t
if not matching:
return [
appcmds.Choice(
name=t(_p(
'cmd:configure_language|acmpl:language',
"No supported languages matching {partial}"
)).format(partial=partial),
value='None'
)
]
else:
return [
appcmds.Choice(name=lang, value=lang)
for lang in matching
]

34
src/babel/enums.py Normal file
View File

@@ -0,0 +1,34 @@
from enum import Enum
class LocaleMap(Enum):
american_english = 'en-US'
british_english = 'en-GB'
bulgarian = 'bg'
chinese = 'zh-CN'
taiwan_chinese = 'zh-TW'
croatian = 'hr'
czech = 'cs'
danish = 'da'
dutch = 'nl'
finnish = 'fi'
french = 'fr'
german = 'de'
greek = 'el'
hindi = 'hi'
hungarian = 'hu'
italian = 'it'
japanese = 'ja'
korean = 'ko'
lithuanian = 'lt'
norwegian = 'no'
polish = 'pl'
brazil_portuguese = 'pt-BR'
romanian = 'ro'
russian = 'ru'
spain_spanish = 'es-ES'
swedish = 'sv-SE'
thai = 'th'
turkish = 'tr'
ukrainian = 'uk'
vietnamese = 'vi'

157
src/babel/translator.py Normal file
View File

@@ -0,0 +1,157 @@
import gettext
import logging
from contextvars import ContextVar
from collections import defaultdict
from enum import Enum
from discord.app_commands import Translator, locale_str
from discord.enums import Locale
logger = logging.getLogger(__name__)
SOURCE_LOCALE = 'en_uk'
ctx_locale: ContextVar[str] = ContextVar('locale', default=SOURCE_LOCALE)
ctx_translator: ContextVar['LeoBabel'] = ContextVar('translator', default=None) # type: ignore
null = gettext.NullTranslations()
class LeoBabel(Translator):
def __init__(self):
self.supported_locales = {loc.name for loc in Locale}
self.supported_domains = {}
self.translators = defaultdict(dict) # locale -> domain -> GNUTranslator
def read_supported(self):
"""
Load supported localisations and domains from the config.
"""
from meta import conf
locales = conf.babel.get('locales', '')
stripped = (loc.strip(', ') for loc in locales.split(','))
self.supported_locales = {loc for loc in stripped if loc}
domains = conf.babel.get('domains', '')
stripped = (dom.strip(', ') for dom in domains.split(','))
self.supported_domains = {dom for dom in stripped if dom}
async def load(self):
"""
Initialise the gettext translators for the supported_locales.
"""
self.read_supported()
for locale in self.supported_locales:
for domain in self.supported_domains:
if locale == SOURCE_LOCALE:
continue
try:
translator = gettext.translation(domain, "locales/", languages=[locale])
except OSError:
# Presume translation does not exist
logger.warning(f"Could not load translator for supported <locale: {locale}> <domain: {domain}>")
pass
else:
logger.debug(f"Loaded translator for <locale: {locale}> <domain: {domain}>")
self.translators[locale][domain] = translator
async def unload(self):
self.translators.clear()
def get_translator(self, locale, domain):
if locale == SOURCE_LOCALE:
return null
translator = self.translators[locale].get(domain, None)
if translator is None:
logger.warning(
f"Translator missing for requested <locale: {locale}> and <domain: {domain}>. Setting NullTranslator."
)
self.translators[locale][domain] = null
translator = null
return translator
def t(self, lazystr, locale=None):
domain = lazystr.domain
translator = self.get_translator(locale or lazystr.locale, domain)
return lazystr._translate_with(translator)
async def translate(self, string: locale_str, locale: Locale, context):
if locale.value in self.supported_locales:
domain = string.extras.get('domain', None)
if domain is None and isinstance(string, LazyStr):
logger.debug(
f"LeoBabel cannot translate a locale_str with no domain set. Context: {context}, String: {string}"
)
return None
translator = self.get_translator(locale.value, domain)
if not isinstance(string, LazyStr):
lazy = LazyStr(Method.GETTEXT, string.message)
else:
lazy = string
return lazy._translate_with(translator)
class Method(Enum):
GETTEXT = 'gettext'
NGETTEXT = 'ngettext'
PGETTEXT = 'pgettext'
NPGETTEXT = 'npgettext'
class LocalBabel:
def __init__(self, domain):
self.domain = domain
@property
def methods(self):
return (self._, self._n, self._p, self._np)
def _(self, message):
return LazyStr(Method.GETTEXT, message, domain=self.domain)
def _n(self, singular, plural, n):
return LazyStr(Method.NGETTEXT, singular, plural, n, domain=self.domain)
def _p(self, context, message):
return LazyStr(Method.PGETTEXT, context, message, domain=self.domain)
def _np(self, context, singular, plural, n):
return LazyStr(Method.NPGETTEXT, context, singular, plural, n, domain=self.domain)
class LazyStr(locale_str):
__slots__ = ('method', 'args', 'domain', 'locale')
def __init__(self, method, *args, locale=None, domain=None):
self.method = method
self.args = args
self.domain = domain
self.locale = locale or ctx_locale.get()
@property
def message(self):
return self._translate_with(null)
@property
def extras(self):
return {'locale': self.locale, 'domain': self.domain}
def __str__(self):
return self.message
def _translate_with(self, translator: gettext.GNUTranslations):
method = getattr(translator, self.method.value)
return method(*self.args)
def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.method}, {self.args!r}, locale={self.locale}, domain={self.domain})'
def __eq__(self, obj: object) -> bool:
return isinstance(obj, locale_str) and self.message == obj.message
def __hash__(self) -> int:
return hash(self.args)

87
src/bot.py Normal file
View File

@@ -0,0 +1,87 @@
import asyncio
import logging
import aiohttp
import discord
from discord.ext import commands
from meta import LionBot, conf, sharding, appname, shard_talk
from meta.app import shardname
from meta.logger import log_context, log_action_stack, logging_context
from meta.context import ctx_bot
from data import Database
from babel.translator import LeoBabel, ctx_translator
from constants import DATA_VERSION
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
logging.getLogger(name).setLevel(conf.logging_levels[name])
logger = logging.getLogger(__name__)
db = Database(conf.data['args'])
async def main():
log_action_stack.set(["Initialising"])
logger.info("Initialising StudyLion")
intents = discord.Intents.all()
intents.members = True
intents.message_content = True
async with await db.connect():
version = await db.version()
if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
logger.critical(error)
raise RuntimeError(error)
translator = LeoBabel()
ctx_translator.set(translator)
async with aiohttp.ClientSession() as session:
async with LionBot(
command_prefix=commands.when_mentioned,
intents=intents,
appname=appname,
shardname=shardname,
db=db,
config=conf,
initial_extensions=['utils', 'core', 'analytics', 'babel', 'modules'],
web_client=session,
app_ipc=shard_talk,
testing_guilds=conf.bot.getintlist('admin_guilds'),
shard_id=sharding.shard_number,
shard_count=sharding.shard_count,
translator=translator
) as lionbot:
ctx_bot.set(lionbot)
try:
with logging_context(context=f"APP: {appname}"):
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
await lionbot.start(conf.bot['TOKEN'])
except asyncio.CancelledError:
with logging_context(context=f"APP: {appname}", action="Shutting Down"):
logger.info("StudyLion closed, shutting down.", exc_info=True)
def _main():
from signal import SIGINT, SIGTERM
loop = asyncio.get_event_loop()
main_task = asyncio.ensure_future(main())
for signal in [SIGINT, SIGTERM]:
loop.add_signal_handler(signal, main_task.cancel)
try:
loop.run_until_complete(main_task)
finally:
loop.close()
logging.shutdown()
if __name__ == '__main__':
_main()

4
src/constants.py Normal file
View File

@@ -0,0 +1,4 @@
CONFIG_FILE = "config/bot.conf"
DATA_VERSION = 13
MAX_COINS = 2147483647 - 1

5
src/core/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from .cog import CoreCog
async def setup(bot):
await bot.add_cog(CoreCog(bot))

89
src/core/cog.py Normal file
View File

@@ -0,0 +1,89 @@
from typing import Optional
import discord
from meta import LionBot, LionCog, LionContext
from meta.app import shardname, appname
from meta.logger import log_wrap
from utils.lib import utc_now
from settings.groups import SettingGroup
from .data import CoreData
from .lion import Lions
from .guild_settings import GuildSettings
from .user_settings import UserSettings
class CoreCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = CoreData()
bot.db.load_registry(self.data)
self.lions = Lions(bot)
self.app_config: Optional[CoreData.AppConfig] = None
self.bot_config: Optional[CoreData.BotConfig] = None
self.shard_data: Optional[CoreData.Shard] = None
# Some global setting registries
# Do not use these for direct setting access
# Instead, import the setting directly or use the cog API
self.bot_setting_groups: list[SettingGroup] = []
self.guild_setting_groups: list[SettingGroup] = []
self.user_setting_groups: list[SettingGroup] = []
# Some ModelSetting registries
# These are for more convenient direct access
self.guild_settings = GuildSettings
self.user_settings = UserSettings
self.app_cmd_cache: list[discord.app_commands.AppCommand] = []
self.cmd_name_cache: dict[str, discord.app_commands.AppCommand] = {}
async def bot_check_once(self, ctx: LionContext): # type: ignore
lion = await self.lions.fetch(ctx.guild.id if ctx.guild else 0, ctx.author.id)
if ctx.guild:
await lion.touch_discord_models(ctx.author) # type: ignore # Type checker doesn't recognise guard
ctx.alion = lion
return True
async def cog_load(self):
# Fetch (and possibly create) core data rows.
conn = await self.bot.db.get_connection()
async with conn.transaction():
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
self.shard_data = await self.data.Shard.fetch_or_create(
shardname,
appname=appname,
shard_id=self.bot.shard_id,
shard_count=self.bot.shard_count
)
self.bot.add_listener(self.shard_update_guilds, name='on_guild_join')
self.bot.add_listener(self.shard_update_guilds, name='on_guild_remove')
self.bot.core = self
await self.bot.add_cog(self.lions)
# Load the app command cache
for guildid in self.bot.testing_guilds:
self.app_cmd_cache += await self.bot.tree.fetch_commands(guild=discord.Object(guildid))
self.app_cmd_cache += await self.bot.tree.fetch_commands()
self.cmd_name_cache = {cmd.name: cmd for cmd in self.app_cmd_cache}
async def cog_unload(self):
await self.bot.remove_cog(self.lions.qualified_name)
self.bot.remove_listener(self.shard_update_guilds, name='on_guild_join')
self.bot.remove_listener(self.shard_update_guilds, name='on_guild_leave')
self.bot.core = None
@LionCog.listener('on_ready')
@log_wrap(action='Touch shard data')
async def touch_shard_data(self):
# Update the last login and guild count for this shard
await self.shard_data.update(last_login=utc_now(), guild_count=len(self.bot.guilds))
@log_wrap(action='Update shard guilds')
async def shard_update_guilds(self, guild):
await self.shard_data.update(guild_count=len(self.bot.guilds))

315
src/core/data.py Normal file
View File

@@ -0,0 +1,315 @@
from itertools import chain
from psycopg import sql
from cachetools import TTLCache
from data import Table, Registry, Column, RowModel
from data.models import WeakCache
from data.columns import Integer, String, Bool, Timestamp
class CoreData(Registry, name="core"):
class AppConfig(RowModel):
"""
Schema
------
CREATE TABLE app_config(
appname TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
_tablename_ = 'app_config'
appname = String(primary=True)
created_at = Timestamp()
class BotConfig(RowModel):
"""
Schema
------
CREATE TABLE bot_config(
appname TEXT PRIMARY KEY REFERENCES app_config(appname) ON DELETE CASCADE,
default_skin TEXT
);
"""
_tablename_ = 'bot_config'
appname = String(primary=True)
default_skin = String()
class Shard(RowModel):
"""
Schema
------
CREATE TABLE shard_data(
shardname TEXT PRIMARY KEY,
appname TEXT REFERENCES bot_config(appname) ON DELETE CASCADE,
shard_id INTEGER NOT NULL,
shard_count INTEGER NOT NULL,
last_login TIMESTAMPTZ,
guild_count INTEGER
);
"""
_tablename_ = 'shard_data'
shardname = String(primary=True)
appname = String()
shard_id = Integer()
shard_count = Integer()
last_login = Timestamp()
guild_count = Integer()
class User(RowModel):
"""
User model, representing configuration data for a single user.
Schema
------
CREATE TABLE user_config(
userid BIGINT PRIMARY KEY,
timezone TEXT,
topgg_vote_reminder BOOLEAN,
avatar_hash TEXT,
name TEXT,
API_timestamp BIGINT,
gems INTEGER DEFAULT 0,
first_seen TIMESTAMPTZ DEFAULT now(),
last_seen TIMESTAMPTZ,
locale TEXT,
locale_hint TEXT
);
"""
_tablename_ = "user_config"
_cache_: WeakCache[tuple[int], 'CoreData.User'] = WeakCache(TTLCache(1000, ttl=60*5))
userid = Integer(primary=True)
timezone = String()
topgg_vote_reminder = Bool()
avatar_hash = String()
name = String()
API_timestamp = Integer()
gems = Integer()
first_seen = Timestamp()
last_seen = Timestamp()
locale = String()
locale_hint = String()
class Guild(RowModel):
"""
Guild model, representing configuration data for a single guild.
Schema
------
CREATE TABLE guild_config(
guildid BIGINT PRIMARY KEY,
admin_role BIGINT,
mod_role BIGINT,
event_log_channel BIGINT,
mod_log_channel BIGINT,
alert_channel BIGINT,
studyban_role BIGINT,
min_workout_length INTEGER,
workout_reward INTEGER,
max_tasks INTEGER,
task_reward INTEGER,
task_reward_limit INTEGER,
study_hourly_reward INTEGER,
study_hourly_live_bonus INTEGER,
renting_price INTEGER,
renting_category BIGINT,
renting_cap INTEGER,
renting_role BIGINT,
renting_sync_perms BOOLEAN,
accountability_category BIGINT,
accountability_lobby BIGINT,
accountability_bonus INTEGER,
accountability_reward INTEGER,
accountability_price INTEGER,
video_studyban BOOLEAN,
video_grace_period INTEGER,
greeting_channel BIGINT,
greeting_message TEXT,
returning_message TEXT,
starting_funds INTEGER,
persist_roles BOOLEAN,
daily_study_cap INTEGER,
pomodoro_channel BIGINT,
name TEXT,
first_joined_at TIMESTAMPTZ DEFAULT now(),
left_at TIMESTAMPTZ,
locale TEXT,
force_locale BOOLEAN
);
"""
_tablename_ = "guild_config"
_cache_: WeakCache[tuple[int], 'CoreData.Guild'] = WeakCache(TTLCache(1000, ttl=60*5))
guildid = Integer(primary=True)
admin_role = Integer()
mod_role = Integer()
event_log_channel = Integer()
mod_log_channel = Integer()
alert_channel = Integer()
studyban_role = Integer()
max_study_bans = Integer()
min_workout_length = Integer()
workout_reward = Integer()
max_tasks = Integer()
task_reward = Integer()
task_reward_limit = Integer()
study_hourly_reward = Integer()
study_hourly_live_bonus = Integer()
daily_study_cap = Integer()
renting_price = Integer()
renting_category = Integer()
renting_cap = Integer()
renting_role = Integer()
renting_sync_perms = Bool()
accountability_category = Integer()
accountability_lobby = Integer()
accountability_bonus = Integer()
accountability_reward = Integer()
accountability_price = Integer()
video_studyban = Bool()
video_grace_period = Integer()
greeting_channel = Integer()
greeting_message = String()
returning_message = String()
starting_funds = Integer()
persist_roles = Bool()
pomodoro_channel = Integer()
name = String()
first_joined_at = Timestamp()
left_at = Timestamp()
locale = String()
force_locale = Bool()
unranked_rows = Table('unranked_rows')
donator_roles = Table('donator_roles')
member_ranks = Table('member_ranks')
class Member(RowModel):
"""
Member model, representing configuration data for a single member.
Schema
------
CREATE TABLE members(
guildid BIGINT,
userid BIGINT,
tracked_time INTEGER DEFAULT 0,
coins INTEGER DEFAULT 0,
workout_count INTEGER DEFAULT 0,
revision_mute_count INTEGER DEFAULT 0,
last_workout_start TIMESTAMP,
last_study_badgeid INTEGER REFERENCES study_badges ON DELETE SET NULL,
video_warned BOOLEAN DEFAULT FALSE,
display_name TEXT,
first_joined TIMESTAMPTZ DEFAULT now(),
last_left TIMESTAMPTZ,
_timestamp TIMESTAMP DEFAULT (now() at time zone 'utc'),
PRIMARY KEY(guildid, userid)
);
CREATE INDEX member_timestamps ON members (_timestamp);
"""
_tablename_ = 'members'
_cache_: WeakCache[tuple[int, int], 'CoreData.Member'] = WeakCache(TTLCache(5000, ttl=60*5))
guildid = Integer(primary=True)
userid = Integer(primary=True)
tracked_time = Integer()
coins = Integer()
workout_count = Integer()
revision_mute_count = Integer()
last_workout_start = Timestamp()
last_study_badgeid = Integer()
video_warned = Bool()
display_name = String()
first_joined = Timestamp()
last_left = Timestamp()
_timestamp = Timestamp()
@classmethod
async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']:
"""
Safely add pending coins to a list of members.
Arguments
---------
pending:
List of tuples of the form `(guildid, userid, pending_coins)`.
"""
query = sql.SQL("""
UPDATE members
SET
coins = LEAST(coins + t.coin_diff, 2147483647)
FROM
(VALUES {})
AS
t (guildid, userid, coin_diff)
WHERE
members.guildid = t.guildid
AND
members.userid = t.userid
RETURNING *
""").format(
sql.SQL(', ').join(
sql.SQL("({}, {}, {})").format(sql.Placeholder(), sql.Placeholder(), sql.Placeholder())
for _ in pending
)
)
# TODO: Replace with copy syntax/query?
conn = await cls.table.connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*pending))
)
rows = await cursor.fetchall()
return cls._make_rows(*rows)
@classmethod
async def get_member_rank(cls, guildid, userid, untracked):
"""
Get the time and coin ranking for the given member, ignoring the provided untracked members.
"""
conn = await cls.table.connector.get_connection()
async with conn.cursor() as curs:
await curs.execute(
"""
SELECT
time_rank, coin_rank
FROM (
SELECT
userid,
row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank,
row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank
FROM members_totals
WHERE
guildid=%s AND userid NOT IN %s
) AS guild_ranks WHERE userid=%s
""",
(guildid, tuple(untracked), userid)
)
return (await curs.fetchone()) or (None, None)

View File

@@ -0,0 +1,7 @@
from settings.groups import ModelSettings, SettingDotDict
from .data import CoreData
class GuildSettings(ModelSettings):
_settings = SettingDotDict()
model = CoreData.Guild

153
src/core/lion.py Normal file
View File

@@ -0,0 +1,153 @@
from typing import Optional
from cachetools import LRUCache
import discord
from meta import LionCog, LionBot, LionContext
from settings import InteractiveSetting
from utils.lib import utc_now
from data import WeakCache
from .data import CoreData
from .user_settings import UserSettings
from .guild_settings import GuildSettings
class Lion:
"""
A Lion is a high level representation of a Member in the LionBot paradigm.
All members interacted with by the application should be available as Lions.
It primarily provides an interface to the User and Member data.
Lion also provides centralised access to various Member properties and methods,
that would normally be served by other cogs.
Many Lion methods may only be used when the required cogs and extensions are loaded.
A Lion may exist without a Bot instance or a Member in cache,
although the functionality available will be more limited.
There is no guarantee that a corresponding discord Member actually exists.
"""
__slots__ = ('bot', 'data', 'user_data', 'guild_data', '_member', '__weakref__')
def __init__(self, bot: LionBot, data: CoreData.Member, user_data: CoreData.User, guild_data: CoreData.Guild):
self.bot = bot
self.data = data
self.user_data = user_data
self.guild_data = guild_data
self._member: Optional[discord.Member] = None
# Data properties
@property
def key(self):
return (self.data.guildid, self.data.userid)
@property
def guildid(self):
return self.data.guildid
@property
def userid(self):
return self.data.userid
@classmethod
def get(cls, guildid, userid):
return cls._cache_.get((guildid, userid), None)
# ModelSettings interfaces
@property
def guild_settings(self):
return GuildSettings(self.guildid, self.guild_data, bot=self.bot)
@property
def user_settings(self):
return UserSettings(self.userid, self.user_data, bot=self.bot)
# Setting interfaces
# Each of these return an initialised member setting
@property
def timezone(self):
pass
@property
def locale(self):
pass
# Time utilities
@property
def now(self):
"""
Returns current time-zone aware time for the member.
"""
pass
# Discord data cache
async def touch_discord_models(self, member: discord.Member):
"""
Update the stored discord data from the given user or member object.
Intended to be used when we get member data from events that may not be available in cache.
"""
# Can we do these in one query?
if member.guild and (self.guild_data.name != member.guild.name):
await self.guild_data.update(name=member.guild.name)
avatar_key = member.avatar.key if member.avatar else None
await self.user_data.update(avatar_hash=avatar_key, name=member.name, last_seen=utc_now())
if member.display_name != self.data.display_name:
await self.data.update(display_name=member.display_name)
async def get_member(self) -> Optional[discord.Member]:
"""
Retrieve the member object for this Lion, if possible.
If the guild or member cannot be retrieved, returns None.
"""
guild = self.bot.get_guild(self.guildid)
if guild is not None:
member = guild.get_member(self.userid)
if member is None:
try:
member = await guild.fetch_member(self.userid)
except discord.HTTPException:
pass
return member
class Lions(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
# Full Lions cache
# Don't expire Lions with strong references
self._cache_: WeakCache[tuple[int, int], 'Lion'] = WeakCache(LRUCache(5000))
self._settings_: dict[str, InteractiveSetting] = {}
async def fetch(self, guildid, userid) -> Lion:
"""
Fetch or create the given Member.
If the guild or user row doesn't exist, also creates it.
Relies on the core cog existing, to retrieve the core data.
"""
# TODO: Find a way to reduce this to one query, while preserving cache
lion = self._cache_.get((guildid, userid))
if lion is None:
if self.bot.core:
data = self.bot.core.data
else:
raise ValueError("Cannot fetch Lion before core module is attached.")
guild = await data.Guild.fetch_or_create(guildid)
user = await data.User.fetch_or_create(userid)
member = await data.Member.fetch_or_create(guildid, userid)
lion = Lion(self.bot, member, user, guild)
self._cache_[(guildid, userid)] = lion
return lion
def add_model_setting(self, setting: InteractiveSetting):
self._settings_[setting.__class__.__name__] = setting
return setting

View File

@@ -0,0 +1,7 @@
from settings.groups import ModelSettings, SettingDotDict
from .data import CoreData
class UserSettings(ModelSettings):
_settings = SettingDotDict()
model = CoreData.User

9
src/data/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
from .conditions import Condition, condition, NULL
from .database import Database
from .models import RowModel, RowTable, WeakCache
from .table import Table
from .base import Expression, RawExpr
from .columns import ColumnExpr, Column, Integer, String
from .registry import Registry, AttachableClass, Attachable
from .adapted import RegisterEnum
from .queries import ORDER, NULLS, JOINTYPE

32
src/data/adapted.py Normal file
View File

@@ -0,0 +1,32 @@
# from enum import Enum
from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo
from .registry import Attachable, Registry
class RegisterEnum(Attachable):
def __init__(self, enum, name: Optional[str] = None, mapper=None):
super().__init__()
self.enum = enum
self.name = name or enum.__name__
self.mapping = mapper(enum) if mapper is not None else self._mapper()
def _mapper(self):
return {m: m.value[0] for m in self.enum}
def attach_to(self, registry: Registry):
self._registry = registry
registry.init_task(self.on_init)
return self
async def on_init(self, registry: Registry):
connector = registry._conn
if connector is None:
raise ValueError("Cannot initialise without connector!")
connection = await connector.get_connection()
if connection is None:
raise ValueError("Cannot Init without connection.")
info = await EnumInfo.fetch(connection, self.name)
if info is None:
raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, connection, self.enum, mapping=list(self.mapping.items()))

45
src/data/base.py Normal file
View File

@@ -0,0 +1,45 @@
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable
from itertools import chain
from psycopg import sql
@runtime_checkable
class Expression(Protocol):
__slots__ = ()
@abstractmethod
def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
raise NotImplementedError
class RawExpr(Expression):
__slots__ = ('expr', 'values')
expr: sql.Composable
values: tuple[Any, ...]
def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
self.expr = expr
self.values = values
def as_tuple(self):
return (self.expr, self.values)
@classmethod
def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
"""
Join a sequence of Expressions into a single RawExpr.
"""
tups = (
expression.as_tuple()
for expression in expressions
)
return cls.join_tuples(*tups, joiner=joiner)
@classmethod
def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
exprs, values = zip(*tuples)
expr = joiner.join(exprs)
value = tuple(chain(*values))
return cls(expr, value)

155
src/data/columns.py Normal file
View File

@@ -0,0 +1,155 @@
from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
from psycopg import sql
from datetime import datetime
from .base import RawExpr, Expression
from .conditions import Condition, Joiner
from .table import Table
class ColumnExpr(RawExpr):
__slots__ = ()
def __lt__(self, obj) -> Condition:
expr, values = self.as_tuple()
if isinstance(obj, Expression):
# column < Expression
obj_expr, obj_values = obj.as_tuple()
cond_exprs = (expr, Joiner.LT, obj_expr)
cond_values = (*values, *obj_values)
else:
# column < Literal
cond_exprs = (expr, Joiner.LT, sql.Placeholder())
cond_values = (*values, obj)
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __le__(self, obj) -> Condition:
expr, values = self.as_tuple()
if isinstance(obj, Expression):
# column <= Expression
obj_expr, obj_values = obj.as_tuple()
cond_exprs = (expr, Joiner.LE, obj_expr)
cond_values = (*values, *obj_values)
else:
# column <= Literal
cond_exprs = (expr, Joiner.LE, sql.Placeholder())
cond_values = (*values, obj)
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __eq__(self, obj) -> Condition: # type: ignore[override]
return Condition._expression_equality(self, obj)
def __ne__(self, obj) -> Condition: # type: ignore[override]
return ~(self.__eq__(obj))
def __gt__(self, obj) -> Condition:
return ~(self.__le__(obj))
def __ge__(self, obj) -> Condition:
return ~(self.__lt__(obj))
def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} + {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} + {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def __sub__(self, obj) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} - {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} - {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def __mul__(self, obj) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} * {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} * {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def CAST(self, target_type: sql.Composable):
return ColumnExpr(
sql.SQL("({}::{})").format(self.expr, target_type),
self.values
)
T = TypeVar('T')
if TYPE_CHECKING:
from .models import RowModel
class Column(ColumnExpr, Generic[T]):
def __init__(self, name: Optional[str] = None,
primary: bool = False, references: Optional['Column'] = None,
type: Optional[Type[T]] = None):
self.primary = primary
self.references = references
self.name: str = name # type: ignore
self.owner: Optional['RowModel'] = None
self._type = type
self.expr = sql.Identifier(name) if name else sql.SQL('')
self.values = ()
def __set_name__(self, owner, name):
# Only allow setting the owner once
self.name = self.name or name
self.owner = owner
self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)
@overload
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
...
@overload
def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T:
...
def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]":
# Get value from row data or session
if obj is None:
return self
else:
return obj.data[self.name]
class Integer(Column[int]):
pass
class String(Column[str]):
pass
class Bool(Column[bool]):
pass
class Timestamp(Column[datetime]):
pass

214
src/data/conditions.py Normal file
View File

@@ -0,0 +1,214 @@
# from meta import sharding
from typing import Any, Union
from enum import Enum
from itertools import chain
from psycopg import sql
from .base import Expression, RawExpr
"""
A Condition is a "logical" database expression, intended for use in Where statements.
Conditions support bitwise logical operators ~, &, |, each producing another Condition.
"""
NULL = None
class Joiner(Enum):
EQUALS = ('=', '!=')
IS = ('IS', 'IS NOT')
LIKE = ('LIKE', 'NOT LIKE')
BETWEEN = ('BETWEEN', 'NOT BETWEEN')
IN = ('IN', 'NOT IN')
LT = ('<', '>=')
LE = ('<=', '>')
NONE = ('', '')
class Condition(Expression):
__slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values')
def __init__(self,
expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''),
values: tuple[Any, ...] = (), negated=False
):
self.expr1 = expr1
self.joiner = joiner
self.negated = negated
self.expr2 = expr2
self.values = values
def as_tuple(self):
expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2))
if self.negated and self.joiner is Joiner.NONE:
expr = sql.SQL("NOT ({})").format(expr)
return (expr, self.values)
@classmethod
def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]):
"""
Construct a Condition from a sequence of Conditions,
together with some explicit column conditions.
"""
# TODO: Consider adding a _table identifier here so we can identify implicit columns
# Or just require subquery type conditions to always come from modelled tables.
implicit_conditions = (
cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
)
return cls._and(*conditions, *implicit_conditions)
@classmethod
def _and(cls, *conditions: 'Condition'):
if not len(conditions):
raise ValueError("Cannot combine 0 Conditions")
if len(conditions) == 1:
return conditions[0]
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
cond_values = tuple(chain(*values))
return Condition(cond_expr, values=cond_values)
@classmethod
def _or(cls, *conditions: 'Condition'):
if not len(conditions):
raise ValueError("Cannot combine 0 Conditions")
if len(conditions) == 1:
return conditions[0]
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
cond_values = tuple(chain(*values))
return Condition(cond_expr, values=cond_values)
@classmethod
def _not(cls, condition: 'Condition'):
condition.negated = not condition.negated
return condition
@classmethod
def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
# TODO: Check if this supports sbqueries
col_expr, col_values = column.as_tuple()
# TODO: Also support sql.SQL? For joins?
if isinstance(value, Expression):
# column = Expression
value_expr, value_values = value.as_tuple()
cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
cond_values = (*col_values, *value_values)
elif isinstance(value, (tuple, list)):
# column in (...)
# TODO: Support expressions in value tuple?
if not value:
raise ValueError("Cannot create Condition from empty iterable!")
value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
cond_exprs = (col_expr, Joiner.IN, value_expr)
cond_values = (*col_values, *value)
elif value is None:
# column IS NULL
cond_exprs = (col_expr, Joiner.IS, sql.NULL)
cond_values = col_values
else:
# column = Literal
cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
cond_values = (*col_values, value)
return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __invert__(self) -> 'Condition':
self.negated = not self.negated
return self
def __and__(self, condition: 'Condition') -> 'Condition':
return self._and(self, condition)
def __or__(self, condition: 'Condition') -> 'Condition':
return self._or(self, condition)
# Helper method to simply condition construction
def condition(*args, **kwargs) -> Condition:
return Condition.construct(*args, **kwargs)
# class NOT(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# if item:
# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
# values.extend(item)
# else:
# raise ValueError("Cannot check an empty iterable!")
# else:
# conditions.append("{}!={}".format(key, _replace_char))
# values.append(item)
#
#
# class GEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply GEQ condition to a list!")
# else:
# conditions.append("{} >= {}".format(key, _replace_char))
# values.append(item)
#
#
# class LEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply LEQ condition to a list!")
# else:
# conditions.append("{} <= {}".format(key, _replace_char))
# values.append(item)
#
#
# class Constant(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# conditions.append("{} {}".format(key, self.value))
#
#
# class SHARDID(Condition):
# __slots__ = ('shardid', 'shard_count')
#
# def __init__(self, shardid, shard_count):
# self.shardid = shardid
# self.shard_count = shard_count
#
# def apply(self, key, values, conditions):
# if self.shard_count > 1:
# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
# values.append(self.shardid)
#
#
# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
#
#
# NULL = Constant('IS NULL')
# NOTNULL = Constant('IS NOT NULL')

58
src/data/connector.py Normal file
View File

@@ -0,0 +1,58 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable
import logging
import psycopg as psq
from .cursor import AsyncLoggingCursor
logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
class Connector:
cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args):
self._conn_args = conn_args
self.conn: psq.AsyncConnection = None
self.conn_hooks = []
async def get_connection(self) -> psq.AsyncConnection:
"""
Get the current active connection.
This should never be cached outside of a transaction.
"""
# TODO: Reconnection logic?
if not self.conn:
raise ValueError("Attempting to get connection before initialisation!")
return self.conn
async def connect(self) -> psq.AsyncConnection:
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
self.conn = await psq.AsyncConnection.connect(
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
)
for hook in self.conn_hooks:
await hook(self.conn)
return self.conn
async def reconnect(self) -> psq.AsyncConnection:
return await self.connect()
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
"""
Minimal decorator to register a coroutine to run on connect or reconnect.
Note that these are only run on connect and reconnect.
If a hook is registered after connection, it will not be run.
"""
self.conn_hooks.append(coro)
return coro
@runtime_checkable
class Connectable(Protocol):
def bind(self, connector: Connector):
raise NotImplementedError

42
src/data/cursor.py Normal file
View File

@@ -0,0 +1,42 @@
import logging
from typing import Optional
from psycopg import AsyncCursor, sql
from psycopg.abc import Query, Params
from psycopg._encodings import pgconn_encoding
logger = logging.getLogger(__name__)
class AsyncLoggingCursor(AsyncCursor):
def mogrify_query(self, query: Query):
if isinstance(query, str):
msg = query
elif isinstance(query, (sql.SQL, sql.Composed)):
msg = query.as_string(self)
elif isinstance(query, bytes):
msg = query.decode(pgconn_encoding(self._conn.pgconn), 'replace')
else:
msg = repr(query)
return msg
async def execute(self, query: Query, params: Optional[Params] = None, **kwargs):
if logging.DEBUG >= logger.getEffectiveLevel():
msg = self.mogrify_query(query)
logger.debug(
"Executing query (%s) with values %s", msg, params,
extra={'action': "Query Execute"}
)
try:
return await super().execute(query, params=params, **kwargs)
except Exception:
msg = self.mogrify_query(query)
logger.exception(
"Exception during query execution. Query (%s) with parameters %s.",
msg, params,
extra={'action': "Query Execute"},
stack_info=True
)
else:
# TODO: Possibly log execution time
pass

46
src/data/database.py Normal file
View File

@@ -0,0 +1,46 @@
from typing import TypeVar
import logging
from collections import namedtuple
# from .cursor import AsyncLoggingCursor
from .registry import Registry
from .connector import Connector
logger = logging.getLogger(__name__)
Version = namedtuple('Version', ('version', 'time', 'author'))
T = TypeVar('T', bound=Registry)
class Database(Connector):
# cursor_factory = AsyncLoggingCursor
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.registries: dict[str, Registry] = {}
def load_registry(self, registry: T) -> T:
logger.debug(
f"Loading and binding registry '{registry.name}'.",
extra={'action': f"Reg {registry.name}"}
)
registry.bind(self)
self.registries[registry.name] = registry
return registry
async def version(self) -> Version:
"""
Return the current schema version as a Version namedtuple.
"""
async with self.conn.cursor() as cursor:
# Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone()
if row:
return Version(row['version'], row['time'], row['author'])
else:
# No versions in the database
return Version(-1, None, None)

320
src/data/models.py Normal file
View File

@@ -0,0 +1,320 @@
from typing import TypeVar, Type, Optional, Generic, Union
# from typing_extensions import Self
from weakref import WeakValueDictionary
from collections.abc import MutableMapping
from psycopg.rows import DictRow
from .table import Table
from .columns import Column
from . import queries as q
from .connector import Connector
from .registry import Registry
RowT = TypeVar('RowT', bound='RowModel')
class MISSING:
__slots__ = ('oid',)
def __init__(self, oid):
self.oid = oid
class RowTable(Table, Generic[RowT]):
__slots__ = (
'model',
)
def __init__(self, name, model: Type[RowT], **kwargs):
super().__init__(name, **kwargs)
self.model = model
@property
def columns(self):
return self.model._columns_
@property
def id_col(self):
return self.model._key_
@property
def row_cache(self):
return self.model._cache_
def _many_query_adapter(self, *data):
self.model._make_rows(*data)
return data
def _single_query_adapter(self, *data):
self.model._make_rows(*data)
return data[0]
def _delete_query_adapter(self, *data):
self.model._delete_rows(*data)
return data
# New methods to fetch and create rows
async def create_row(self, *args, **kwargs) -> RowT:
data = await super().insert(*args, **kwargs)
return self.model._make_rows(data)[0]
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
# TODO: Handle list of rowids here?
return q.Select(
self.identifier,
row_adapter=self.model._make_rows,
connector=self.connector
).where(*args, **kwargs)
WK = TypeVar('WK')
WV = TypeVar('WV')
class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
def __init__(self, ref_cache):
self.ref_cache = ref_cache
self.weak_cache = WeakValueDictionary()
def __getitem__(self, key):
value = self.weak_cache[key]
self.ref_cache[key] = value
return value
def __setitem__(self, key, value):
self.weak_cache[key] = value
self.ref_cache[key] = value
def __delitem__(self, key):
del self.weak_cache[key]
try:
del self.ref_cache[key]
except KeyError:
pass
def __contains__(self, key):
return key in self.weak_cache
def __iter__(self):
return iter(self.weak_cache)
def __len__(self):
return len(self.weak_cache)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def pop(self, key, default=None):
if key in self:
value = self[key]
del self[key]
else:
value = default
return value
# TODO: Implement getitem and setitem, for dynamic column access
class RowModel:
__slots__ = ('data',)
_schema_: str = 'public'
_tablename_: Optional[str] = None
_columns_: dict[str, Column] = {}
# Cache to keep track of registered Rows
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
_key_: tuple[str, ...] = ()
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
# TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
table: RowTable
def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
"""
Set table, _columns_, and _key_.
"""
if table is not None:
cls._tablename_ = table
if cls._tablename_ is not None:
columns = {}
for key, value in cls.__dict__.items():
if isinstance(value, Column):
columns[key] = value
cls._columns_ = columns
if not cls._key_:
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
if cls._cache_ is None:
cls._cache_ = WeakValueDictionary()
def __new__(cls, data):
# Registry pattern.
# Ensure each rowid always refers to a single Model instance
if data is not None:
rowid = cls._id_from_data(data)
cache = cls._cache_
if (row := cache.get(rowid, None)) is not None:
obj = row
else:
obj = cache[rowid] = super().__new__(cls)
else:
obj = super().__new__(cls)
return obj
@classmethod
def as_tuple(cls):
return (cls.table.identifier, ())
def __init__(self, data):
self.data = data
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, value):
self.data[key] = value
@classmethod
def bind(cls, connector: Connector):
if cls.table is None:
raise ValueError("Cannot bind abstract RowModel")
cls._connector = connector
cls.table.bind(connector)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@property
def _dict_(self):
return {key: self.data[key] for key in self._key_}
@property
def _rowid_(self):
return tuple(self.data[key] for key in self._key_)
def __repr__(self):
return "{}.{}({})".format(
self.table.schema,
self.table.name,
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
)
@classmethod
def _id_from_data(cls, data):
return tuple(data[key] for key in cls._key_)
@classmethod
def _dict_from_id(cls, rowid):
return dict(zip(cls._key_, rowid))
@classmethod
def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
"""
Create or retrieve Row objects for each provided data row.
If the rows already exist in cache, updates the cached row.
"""
# TODO: Handle partial row data here somehow?
rows = [cls(data_row) for data_row in data_rows]
return rows
@classmethod
def _delete_rows(cls, *data_rows):
"""
Remove the given rows from cache, if they exist.
May be extended to handle object deletion.
"""
cache = cls._cache_
for data_row in data_rows:
rowid = cls._id_from_data(data_row)
cache.pop(rowid, None)
@classmethod
async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
return await cls.table.create_row(*args, **kwargs)
@classmethod
def fetch_where(cls: Type[RowT], *args, **kwargs):
return cls.table.fetch_rows_where(*args, **kwargs)
@classmethod
async def fetch(cls: Type[RowT], *rowid) -> Optional[RowT]:
"""
Fetch the row with the given id, retrieving from cache where possible.
"""
row = cls._cache_.get(rowid, None)
if row is None:
rows = await cls.fetch_where(**cls._dict_from_id(rowid))
row = rows[0] if rows else None
if row is None:
cls._cache_[rowid] = cls(None)
elif row.data is None:
row = None
return row
@classmethod
async def fetch_or_create(cls, *rowid, **kwargs):
"""
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
"""
if rowid:
row = await cls.fetch(*rowid)
else:
rows = await cls.fetch_where(**kwargs).limit(1)
row = rows[0] if rows else None
if row is None:
creation_kwargs = kwargs
if rowid:
creation_kwargs.update(cls._dict_from_id(rowid))
row = await cls.create(**creation_kwargs)
return row
async def refresh(self: RowT) -> Optional[RowT]:
"""
Refresh this Row from data.
The return value may be `None` if the row was deleted.
"""
rows = await self.table.select_where(**self._dict_)
if not rows:
return None
else:
self.data = rows[0]
return self
async def update(self: RowT, **values) -> Optional[RowT]:
"""
Update this Row with the given values.
Internally passes the provided `values` to the `update` Query.
The return value may be `None` if the row was deleted.
"""
data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
if not data:
return None
else:
return data[0]
async def delete(self: RowT) -> Optional[RowT]:
"""
Delete this Row.
"""
data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
return data[0] if data is not None else None

592
src/data/queries.py Normal file
View File

@@ -0,0 +1,592 @@
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
from enum import Enum
from itertools import chain
from psycopg import AsyncConnection, AsyncCursor
from psycopg import sql
from psycopg.rows import DictRow
import logging
from .conditions import Condition
from .base import Expression, RawExpr
from .connector import Connector
logger = logging.getLogger(__name__)
TQueryT = TypeVar('TQueryT', bound='TableQuery')
SQueryT = TypeVar('SQueryT', bound='Select')
QueryResult = TypeVar('QueryResult')
class Query(Generic[QueryResult]):
"""
ABC for an executable query statement.
"""
__slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')
_adapter: Callable[..., QueryResult]
def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
self.connector: Optional[Connector] = connector
self.conn: Optional[AsyncConnection] = conn
self.cursor: Optional[AsyncCursor] = cursor
if row_adapter is not None:
self._adapter = row_adapter
else:
self._adapter = self._no_adapter
self.result: Optional[QueryResult] = None
def bind(self, connector: Connector):
self.connector = connector
return self
def with_cursor(self, cursor: AsyncCursor):
self.cursor = cursor
return self
def with_connection(self, conn: AsyncConnection):
self.conn = conn
return self
def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def with_adapter(self, callable: Callable[..., QueryResult]):
# NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
# For this to work cleanly, callable should have arg type of QR1, not any
self._adapter = callable
return self
def with_no_adapter(self):
"""
Sets the adapater to the identity.
"""
self._adapter = self._no_adapter
return self
def one(self):
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
return self
def build(self) -> Expression:
raise NotImplementedError
async def _execute(self, cursor: AsyncCursor) -> QueryResult:
query, values = self.build().as_tuple()
# TODO: Move logging out to a custom cursor
# logger.debug(
# f"Executing query ({query.as_string(cursor)}) with values {values}",
# extra={'action': "Query"}
# )
await cursor.execute(sql.Composed((query,)), values)
data = await cursor.fetchall()
self.result = self._adapter(*data)
return self.result
async def execute(self, cursor=None) -> QueryResult:
"""
Execute the query, optionally with the provided cursor, and return the result rows.
If no cursor is provided, and no cursor has been set with `with_cursor`,
the execution will create a new cursor from the connection and close it automatically.
"""
# Create a cursor if possible
cursor = cursor if cursor is not None else self.cursor
if self.cursor is None:
if self.conn is None:
if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.")
else:
conn = await self.connector.get_connection()
else:
conn = self.conn
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else:
data = await self._execute(cursor)
return data
def __await__(self):
return self.execute().__await__()
class TableQuery(Query[QueryResult]):
"""
ABC for an executable query statement expected to be run on a single table.
"""
__slots__ = (
'tableid',
'condition', '_extra', '_limit', '_order', '_joins'
)
def __init__(self, tableid, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tableid: sql.Identifier = tableid
def options(self, **kwargs):
"""
Set some query options.
Default implementation does nothing.
Should be overridden to provide specific options.
"""
return self
class WhereMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.condition: Optional[Condition] = None
def where(self, *args: Condition, **kwargs):
"""
Add a Condition to the query.
Position arguments should be Conditions,
and keyword arguments should be of the form `column=Value`,
where Value may be a Value-type or a literal value.
All provided Conditions will be and-ed together to create a new Condition.
TODO: Maybe just pass this verbatim to a condition.
"""
if args or kwargs:
condition = Condition.construct(*args, **kwargs)
if self.condition is not None:
condition = self.condition & condition
self.condition = condition
return self
@property
def _where_section(self) -> Optional[Expression]:
if self.condition is not None:
return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
else:
return None
class JOINTYPE(Enum):
LEFT = sql.SQL('LEFT JOIN')
RIGHT = sql.SQL('RIGHT JOIN')
INNER = sql.SQL('INNER JOIN')
OUTER = sql.SQL('OUTER JOIN')
FULLOUTER = sql.SQL('FULL OUTER JOIN')
class JoinMixin(TableQuery[QueryResult]):
__slots__ = ()
# TODO: Remember to add join slots to TableQuery
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._joins: list[Expression] = []
def join(self,
target: Union[str, Expression],
on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
join_type: JOINTYPE = JOINTYPE.INNER,
natural=False):
available = (on is not None) + (using is not None) + natural
if available == 0:
raise ValueError("No conditions given for Query Join")
if available > 1:
raise ValueError("Exactly one join format must be given for Query Join")
sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
if isinstance(target, str):
sections.append((sql.Identifier(target), ()))
else:
sections.append(target.as_tuple())
if on is not None:
sections.append((sql.SQL('ON'), ()))
sections.append(on.as_tuple())
elif using is not None:
sections.append((sql.SQL('USING'), ()))
if isinstance(using, Expression):
sections.append(using.as_tuple())
elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str):
cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using))
sections.append((cols, ()))
else:
raise ValueError("Unrecognised 'using' type.")
elif natural:
sections.insert(0, (sql.SQL('NATURAL'), ()))
expr = RawExpr.join_tuples(*sections)
self._joins.append(expr)
return self
def leftjoin(self, *args, **kwargs):
return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)
@property
def _join_section(self) -> Optional[Expression]:
if self._joins:
return RawExpr.join(*self._joins)
else:
return None
class ExtraMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._extra: Optional[Expression] = None
def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
"""
Add an extra string, and optionally values, to this query.
The extra string is inserted after any condition, and before the limit.
"""
extra_expr = RawExpr(extra, values)
if self._extra is not None:
extra_expr = RawExpr.join(self._extra, extra_expr)
self._extra = extra_expr
return self
@property
def _extra_section(self) -> Optional[Expression]:
if self._extra is None:
return None
else:
return self._extra
class LimitMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._limit: Optional[int] = None
def limit(self, limit: int):
"""
Add a limit to this query.
"""
self._limit = limit
return self
@property
def _limit_section(self) -> Optional[Expression]:
if self._limit is not None:
return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
else:
return None
class ORDER(Enum):
ASC = sql.SQL('ASC')
DESC = sql.SQL('DESC')
class NULLS(Enum):
FIRST = sql.SQL('NULLS FIRST')
LAST = sql.SQL('NULLS LAST')
class OrderMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._order: list[Expression] = []
def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
"""
Add a single sort expression to the query.
This method stacks.
"""
if isinstance(expr, Expression):
string, values = expr.as_tuple()
else:
string = sql.Identifier(expr)
values = ()
parts = [string]
if direction is not None:
parts.append(direction.value)
if nulls is not None:
parts.append(nulls.value)
order_string = sql.SQL(' ').join(parts)
self._order.append(RawExpr(order_string, values))
return self
@property
def _order_section(self) -> Optional[Expression]:
if self._order:
expr = RawExpr.join(*self._order, joiner=sql.SQL(', '))
expr.expr = sql.SQL("ORDER BY {}").format(expr.expr)
return expr
else:
return None
class Insert(ExtraMixin, TableQuery[QueryResult]):
"""
Query type representing a table insert query.
"""
# TODO: Support ON CONFLICT for upserts
__slots__ = ('_columns', '_values', '_conflict')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._columns: tuple[str, ...] = ()
self._values: tuple[tuple[Any, ...], ...] = ()
self._conflict: Optional[Expression] = None
def insert(self, columns, *values):
"""
Insert the given data.
Parameters
----------
columns: tuple[str]
Tuple of column names to insert.
values: tuple[tuple[Any, ...], ...]
Tuple of values to insert, corresponding to the columns.
"""
if not values:
raise ValueError("Cannot insert zero rows.")
if len(values[0]) != len(columns):
raise ValueError("Number of columns does not match length of values.")
self._columns = columns
self._values = values
return self
def on_conflict(self, ignore=False):
# TODO lots more we can do here
# Maybe return a Conflict object that can chain itself (not the query)
if ignore:
self._conflict = RawExpr(sql.SQL('DO NOTHING'))
return self
@property
def _conflict_section(self) -> Optional[Expression]:
if self._conflict is not None:
e, v = self._conflict.as_tuple()
expr = RawExpr(
sql.SQL("ON CONFLICT {}").format(
e
),
v
)
return expr
return None
def build(self):
columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
single_value_str = sql.SQL('({})').format(
sql.SQL(',').join(sql.Placeholder() * len(self._columns))
)
values_str = sql.SQL(',').join(single_value_str * len(self._values))
# TODO: Check efficiency of inserting multiple values like this
# Also implement a Copy query
base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
table=self.tableid,
columns=columns,
values_str=values_str
)
sections = [
RawExpr(base, tuple(chain(*self._values))),
self._conflict_section,
self._extra_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, TableQuery[QueryResult]):
"""
Select rows from a table matching provided conditions.
"""
__slots__ = ('_columns',)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._columns: tuple[Expression, ...] = ()
def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
"""
Set the columns and expressions to select.
If none are given, selects all columns.
"""
cols: List[Expression] = []
if columns:
cols.extend(map(RawExpr, map(sql.Identifier, columns)))
if exprs:
for name, expr in exprs.items():
if isinstance(expr, str):
cols.append(
RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name))
)
elif isinstance(expr, sql.Composable):
cols.append(
RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name))
)
elif isinstance(expr, Expression):
value_expr, value_values = expr.as_tuple()
cols.append(RawExpr(
value_expr + sql.SQL(' AS ') + sql.Identifier(name),
value_values
))
if cols:
self._columns = (*self._columns, *cols)
return self
def build(self):
if not self._columns:
columns, columns_values = sql.SQL('*'), ()
else:
columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()
base = sql.SQL("SELECT {columns} FROM {table}").format(
columns=columns,
table=self.tableid
)
sections = [
RawExpr(base, columns_values),
self._join_section,
self._where_section,
self._extra_section,
self._order_section,
self._limit_section,
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
"""
Query type representing a table delete query.
"""
# TODO: Cascade option for delete, maybe other options
# TODO: Require a where unless specifically disabled, for safety
def build(self):
base = sql.SQL("DELETE FROM {table}").format(
table=self.tableid,
)
sections = [
RawExpr(base),
self._where_section,
self._extra_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Update(LimitMixin, WhereMixin, ExtraMixin, TableQuery[QueryResult]):
__slots__ = (
'_set',
)
# TODO: Again, require a where unless specifically disabled
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._set: List[Expression] = []
def set(self, **column_values: Union[Any, Expression]):
exprs: List[Expression] = []
for name, value in column_values.items():
if isinstance(value, Expression):
value_tup = value.as_tuple()
else:
value_tup = (sql.Placeholder(), (value,))
exprs.append(
RawExpr.join_tuples(
(sql.Identifier(name), ()),
value_tup,
joiner=sql.SQL(' = ')
)
)
self._set.extend(exprs)
return self
def build(self):
if not self._set:
raise ValueError("No columns provided to update.")
set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
base = sql.SQL("UPDATE {table} SET {set}").format(
table=self.tableid,
set=set_expr
)
sections = [
RawExpr(base, set_values),
self._where_section,
self._extra_section,
self._limit_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
# async def upsert(cursor, table, constraint, **values):
# """
# Insert or on conflict update.
# """
# valuedict = values
# keys, values = zip(*values.items())
#
# key_str = _format_insertkeys(keys)
# value_str, values = _format_insertvalues(values)
# update_key_str, update_key_values = _format_updatestr(valuedict)
#
# if not isinstance(constraint, str):
# constraint = ", ".join(constraint)
#
# await cursor.execute(
# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
# table, key_str, value_str, constraint, update_key_str
# ),
# tuple((*values, *update_key_values))
# )
# return await cursor.fetchone()
# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
# cursor = cursor or conn.cursor()
#
# # TODO: executemany or copy syntax now
# return execute_values(
# cursor,
# """
# UPDATE {table}
# SET {set_clause}
# FROM (VALUES {cast_row}%s)
# AS {temp_table}
# WHERE {where_clause}
# RETURNING *
# """.format(
# table=table,
# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
# cast_row=cast_row + ',' if cast_row else '',
# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
# temp_table="_t ({})".format(', '.join(set_keys + where_keys))
# ),
# values,
# fetch=True
# )

102
src/data/registry.py Normal file
View File

@@ -0,0 +1,102 @@
from typing import Protocol, runtime_checkable, Optional
from psycopg import AsyncConnection
from .connector import Connector, Connectable
@runtime_checkable
class _Attachable(Connectable, Protocol):
def attach_to(self, registry: 'Registry'):
raise NotImplementedError
class Registry:
_attached: list[_Attachable] = []
_name: Optional[str] = None
def __init_subclass__(cls, name=None):
attached = []
for _, member in cls.__dict__.items():
if isinstance(member, _Attachable):
attached.append(member)
cls._attached = attached
cls._name = name or cls.__name__
def __init__(self, name=None):
self._conn: Optional[Connector] = None
self.name: str = name if name is not None else self._name
if self.name is None:
raise ValueError("A Registry must have a name!")
self.init_tasks = []
for member in self._attached:
member.attach_to(self)
def bind(self, connector: Connector):
self._conn = connector
for child in self._attached:
child.bind(connector)
def attach(self, attachable):
self._attached.append(attachable)
if self._conn is not None:
attachable.bind(self._conn)
return attachable
def init_task(self, coro):
"""
Initialisation tasks are run to setup the registry state.
These tasks will be run in the event loop, after connection to the database.
These tasks should be idempotent, as they may be run on reload and reconnect.
"""
self.init_tasks.append(coro)
return coro
async def init(self):
for task in self.init_tasks:
await task(self)
return self
class AttachableClass:
"""ABC for a default implementation of an Attachable class."""
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
@classmethod
def bind(cls, connector: Connector):
cls._connector = connector
connector.connect_hook(cls.on_connect)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@classmethod
async def on_connect(cls, connection: AsyncConnection):
pass
class Attachable:
"""ABC for a default implementation of an Attachable object."""
def __init__(self, *args, **kwargs):
self._connector: Optional[Connector] = None
self._registry: Optional[Registry] = None
def bind(self, connector: Connector):
self._connector = connector
connector.connect_hook(self.on_connect)
return self
def attach_to(self, registry: Registry):
self._registry = registry
return self
async def on_connect(self, connection: AsyncConnection):
pass

95
src/data/table.py Normal file
View File

@@ -0,0 +1,95 @@
from typing import Optional
from psycopg.rows import DictRow
from psycopg import sql
from . import queries as q
from .connector import Connector
from .registry import Registry
class Table:
"""
Transparent interface to a single table structure in the database.
Contains standard methods to access the table.
"""
def __init__(self, name, *args, schema='public', **kwargs):
self.name: str = name
self.schema: str = schema
self.connector: Connector = None
@property
def identifier(self):
if self.schema == 'public':
return sql.Identifier(self.name)
else:
return sql.Identifier(self.schema, self.name)
def bind(self, connector: Connector):
self.connector = connector
return self
def attach_to(self, registry: Registry):
self._registry = registry
return self
def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]:
if data:
return data[0]
else:
return None
def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
return q.Select(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
return q.Select(
self.identifier,
row_adapter=self._single_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
return q.Update(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
return q.Delete(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def insert(self, **column_values) -> q.Insert[DictRow]:
return q.Insert(
self.identifier,
row_adapter=self._single_query_adapter,
connector=self.connector
).insert(column_values.keys(), column_values.values())
def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
return q.Insert(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).insert(*args, **kwargs)
# def update_many(self, *args, **kwargs):
# with self.conn:
# return update_many(self.identifier, *args, **kwargs)
# def upsert(self, *args, **kwargs):
# return upsert(self.identifier, *args, **kwargs)

1
src/gui Submodule

Submodule src/gui added at 76659f1193

200
src/meta/LionBot.py Normal file
View File

@@ -0,0 +1,200 @@
from typing import List, Optional, TYPE_CHECKING
import logging
import asyncio
import discord
from discord.utils import MISSING
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError
from aiohttp import ClientSession
from data import Database
from .config import Conf
from .logger import logging_context, log_context, log_action_stack
from .context import context
from .LionContext import LionContext
from .LionTree import LionTree
from .errors import HandledException, SafeCancellation
if TYPE_CHECKING:
from core import CoreCog
logger = logging.getLogger(__name__)
class LionBot(Bot):
def __init__(
self, *args, appname: str, shardname: str, db: Database, config: Conf,
initial_extensions: List[str], web_client: ClientSession, app_ipc,
testing_guilds: List[int] = [], translator=None, **kwargs
):
kwargs.setdefault('tree_cls', LionTree)
super().__init__(*args, **kwargs)
self.web_client = web_client
self.testing_guilds = testing_guilds
self.initial_extensions = initial_extensions
self.db = db
self.appname = appname
self.shardname = shardname
# self.appdata = appdata
self.config = config
self.app_ipc = app_ipc
self.core: Optional['CoreCog'] = None
self.translator = translator
async def setup_hook(self) -> None:
log_context.set(f"APP: {self.application_id}")
await self.app_ipc.connect()
if self.translator is not None:
await self.tree.set_translator(self.translator)
for extension in self.initial_extensions:
await self.load_extension(extension)
for guildid in self.testing_guilds:
guild = discord.Object(guildid)
self.tree.copy_global_to(guild=guild)
await self.tree.sync(guild=guild)
async def add_cog(self, cog: Cog, **kwargs):
with logging_context(action=f"Attach {cog.__cog_name__}"):
logger.info(f"Attaching Cog {cog.__cog_name__}")
await super().add_cog(cog, **kwargs)
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
async def load_extension(self, name, *, package=None, **kwargs):
with logging_context(action=f"Load {name.strip('.')}"):
logger.info(f"Loading extension {name} in package {package}.")
await super().load_extension(name, package=package, **kwargs)
logger.debug(f"Loaded extension {name} in package {package}.")
async def start(self, token: str, *, reconnect: bool = True):
with logging_context(action="Login"):
await self.login(token)
with logging_context(stack=["Running"]):
await self.connect(reconnect=reconnect)
async def on_ready(self):
logger.info(
f"Logged in as {self.application.name}\n"
f"Application id {self.application.id}\n"
f"Shard Talk identifier {self.shardname}\n"
"------------------------------\n"
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
"------------------------------\n"
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
"Ready to take commands!\n",
extra={'action': 'Ready'}
)
async def get_context(self, origin, /, *, cls=MISSING):
if cls is MISSING:
cls = LionContext
ctx = await super().get_context(origin, cls=cls)
context.set(ctx)
return ctx
async def on_command(self, ctx: LionContext):
logger.info(
f"Executing command '{ctx.command.qualified_name}' "
f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') "
f"with arguments {ctx.args} and kwargs {ctx.kwargs}.",
extra={'with_ctx': True}
)
async def on_command_error(self, ctx, exception):
# TODO: Some of these could have more user-feedback
cmd_str = str(ctx.command)
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
cmd_str = ctx.command.app_command.to_dict()
try:
raise exception
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
try:
if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
original = exception.original.original
raise original
else:
original = exception.original
raise original
except HandledException:
pass
except SafeCancellation:
if original.msg:
try:
await ctx.error_reply(original.msg)
except Exception:
pass
logger.debug(
f"Caught a safe cancellation: {original.details}",
extra={'action': 'BotError', 'with_ctx': True}
)
except discord.Forbidden:
# Unknown uncaught Forbidden
try:
# Attempt a general error reply
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
except Exception:
# We can't send anything at all. Exit quietly, but log.
logger.warning(
f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
exc_info=True,
extra={'action': 'BotError', 'with_ctx': True}
)
except discord.HTTPException:
logger.warning(
f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
exc_info=True,
extra={'action': 'BotError', 'with_ctx': True}
)
except asyncio.CancelledError:
pass
except asyncio.TimeoutError:
pass
except Exception:
logger.exception(
f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
extra={'action': 'BotError', 'with_ctx': True}
)
error_embed = discord.Embed(title="Something went wrong!")
error_embed.description = (
"An unexpected error occurred while processing your command!\n"
"Our development team has been notified, and the issue should be fixed soon.\n"
"If the error persists, please contact our support team and give them the following number: "
f"`{ctx.interaction.id}`"
)
try:
await ctx.error_reply(embed=error_embed)
except Exception:
pass
finally:
exception.original = HandledException(exception.original)
except CheckFailure:
logger.debug(
f"Command failed check: {exception}",
extra={'action': 'BotError', 'with_ctx': True}
)
try:
await ctx.error_rely(exception.message)
except Exception:
pass
except Exception:
# Completely unknown exception outside of command invocation!
# Something is very wrong here, don't attempt user interaction.
logger.exception(
f"Caught an unknown top-level exception while executing: {cmd_str}",
extra={'action': 'BotError', 'with_ctx': True}
)
def add_command(self, command):
if hasattr(command, '_placeholder_group_'):
return
super().add_command(command)

58
src/meta/LionCog.py Normal file
View File

@@ -0,0 +1,58 @@
from typing import Any
from discord.ext.commands import Cog
from discord.ext import commands as cmds
class LionCog(Cog):
# A set of other cogs that this cog depends on
depends_on: set['LionCog'] = set()
_placeholder_groups_: set[str]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._placeholder_groups_ = set()
for base in reversed(cls.__mro__):
for elem, value in base.__dict__.items():
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
cls._placeholder_groups_.add(value.name)
def __new__(cls, *args: Any, **kwargs: Any):
# Patch to ensure no placeholder groups are in the command list
self = super().__new__(cls)
self.__cog_commands__ = [
command for command in self.__cog_commands__ if command.name not in cls._placeholder_groups_
]
return self
async def _inject(self, bot, *args, **kwargs):
if self.depends_on:
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
return await super()._inject(bot, *args, *kwargs)
@classmethod
def placeholder_group(cls, group: cmds.HybridGroup):
group._placeholder_group_ = True
return group
def crossload_group(self, placeholder_group: cmds.HybridGroup, target_group: cmds.HybridGroup):
"""
Crossload a placeholder group's commands into the target group
"""
if not isinstance(placeholder_group, cmds.HybridGroup) or not isinstance(target_group, cmds.HybridGroup):
raise ValueError("Placeholder and target groups my be HypridGroups.")
if placeholder_group.name not in self._placeholder_groups_:
raise ValueError("Placeholder group was not registered! Stopping to avoid duplicates.")
if target_group.app_command is None:
raise ValueError("Target group has no app_command to crossload into.")
for command in placeholder_group.commands:
placeholder_group.remove_command(command.name)
target_group.remove_command(command.name)
acmd = command.app_command._copy_with(parent=target_group.app_command, binding=self)
command.app_command = acmd
target_group.add_command(command)

190
src/meta/LionContext.py Normal file
View File

@@ -0,0 +1,190 @@
import types
import logging
from collections import namedtuple
from typing import Optional, TYPE_CHECKING
import discord
from discord.ext.commands import Context
if TYPE_CHECKING:
from .LionBot import LionBot
from core.lion import Lion
logger = logging.getLogger(__name__)
"""
Stuff that might be useful to implement (see cmdClient):
sent_messages cache
tasks cache
error reply
usage
interaction cache
View cache?
setting access
"""
FlatContext = namedtuple(
'FlatContext',
('message',
'interaction',
'guild',
'author',
'alias',
'prefix',
'failed')
)
class LionContext(Context['LionBot']):
"""
Represents the context a command is invoked under.
Extends Context to add Lion-specific methods and attributes.
Also adds several contextual wrapped utilities for simpler user during command invocation.
"""
alion: 'Lion'
def __repr__(self):
parts = {}
if self.interaction is not None:
parts['iid'] = self.interaction.id
parts['itype'] = f"\"{self.interaction.type.name}\""
if self.message is not None:
parts['mid'] = self.message.id
if self.author is not None:
parts['uid'] = self.author.id
parts['uname'] = f"\"{self.author.name}\""
if self.channel is not None:
parts['cid'] = self.channel.id
parts['cname'] = f"\"{self.channel.name}\""
if self.guild is not None:
parts['gid'] = self.guild.id
parts['gname'] = f"\"{self.guild.name}\""
if self.command is not None:
parts['cmd'] = f"\"{self.command.qualified_name}\""
if self.invoked_with is not None:
parts['alias'] = f"\"{self.invoked_with}\""
if self.command_failed:
parts['failed'] = self.command_failed
return "<LionContext: {}>".format(
' '.join(f"{name}={value}" for name, value in parts.items())
)
def flatten(self):
"""Flat pure-data context information, for caching and logging."""
return FlatContext(
self.message.id,
self.interaction.id if self.interaction is not None else None,
self.guild.id if self.guild is not None else None,
self.author.id if self.author is not None else None,
self.channel.id if self.channel is not None else None,
self.invoked_with,
self.prefix,
self.command_failed
)
@classmethod
def util(cls, util_func):
"""
Decorator to make a utility function available as a Context instance method.
"""
setattr(cls, util_func.__name__, util_func)
logger.debug(f"Attached context utility function: {util_func.__name__}")
return util_func
@classmethod
def wrappable_util(cls, util_func):
"""
Decorator to add a Wrappable utility function as a Context instance method.
"""
wrapped = Wrappable(util_func)
setattr(cls, util_func.__name__, wrapped)
logger.debug(f"Attached wrappable context utility function: {util_func.__name__}")
return wrapped
async def error_reply(self, content: Optional[str] = None, **kwargs):
if content and 'embed' not in kwargs:
embed = discord.Embed(
colour=discord.Colour.red(),
description=content
)
kwargs['embed'] = embed
content = None
# Expect this may be run in highly unusual circumstances.
# This should never error, or at least handle all errors.
try:
await self.reply(content=content, **kwargs)
except discord.HTTPException:
pass
except Exception:
logger.exception(
"Unknown exception in 'error_reply'.",
extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True}
)
class Wrappable:
__slots__ = ('_func', 'wrappers')
def __init__(self, func):
self._func = func
self.wrappers = None
@property
def __name__(self):
return self._func.__name__
def add_wrapper(self, func, name=None):
self.wrappers = self.wrappers or {}
name = name or func.__name__
self.wrappers[name] = func
logger.debug(
f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.",
extra={'action': "Wrap Util"}
)
def remove_wrapper(self, name):
if not self.wrappers or name not in self.wrappers:
raise ValueError(
f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
)
self.wrappers.pop(name)
logger.debug(
f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
extra={'action': "Unwrap Util"}
)
def __call__(self, *args, **kwargs):
if self.wrappers:
return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)
else:
return self._func(*args, **kwargs)
def _wrapped(self, iter_wraps):
next_wrap = next(iter_wraps, None)
if next_wrap:
def _func(*args, **kwargs):
return next_wrap(self._wrapped(iter_wraps), *args, **kwargs)
else:
_func = self._func
return _func
def __get__(self, instance, cls=None):
if instance is None:
return self
else:
return types.MethodType(self, instance)
LionContext.reply = Wrappable(LionContext.reply)
# @LionContext.reply.add_wrapper
# async def think(func, ctx, *args, **kwargs):
# await ctx.channel.send("thinking")
# await func(ctx, *args, **kwargs)

79
src/meta/LionTree.py Normal file
View File

@@ -0,0 +1,79 @@
import logging
from discord import Interaction
from discord.app_commands import CommandTree
from discord.app_commands.errors import AppCommandError, CommandInvokeError
from discord.enums import InteractionType
from discord.app_commands.namespace import Namespace
from .logger import logging_context
from .errors import SafeCancellation
logger = logging.getLogger(__name__)
class LionTree(CommandTree):
async def on_error(self, interaction, error) -> None:
try:
if isinstance(error, CommandInvokeError):
raise error.original
else:
raise error
except SafeCancellation:
# Assume this has already been handled
pass
except Exception:
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
async def _call(self, interaction):
with logging_context(context=f"iid: {interaction.id}"):
if not await self.interaction_check(interaction):
interaction.command_failed = True
return
data = interaction.data # type: ignore
type = data.get('type', 1)
if type != 1:
# Context menu command...
await self._call_context_menu(interaction, data, type)
return
command, options = self._get_app_command_options(data)
# Pre-fill the cached slot to prevent re-computation
interaction._cs_command = command
# At this point options refers to the arguments of the command
# and command refers to the class type we care about
namespace = Namespace(interaction, data.get('resolved', {}), options)
# Same pre-fill as above
interaction._cs_namespace = namespace
# Auto complete handles the namespace differently... so at this point this is where we decide where that is.
if interaction.type is InteractionType.autocomplete:
with logging_context(action=f"Acmp {command.qualified_name}"):
focused = next((opt['name'] for opt in options if opt.get('focused')), None)
if focused is None:
raise AppCommandError(
'This should not happen, but there is no focused element. This is a Discord bug.'
)
await command._invoke_autocomplete(interaction, focused, namespace)
return
with logging_context(action=f"Run {command.qualified_name}"):
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
try:
await command._invoke_with_namespace(interaction, namespace)
except AppCommandError as e:
interaction.command_failed = True
await command._invoke_error_handlers(interaction, e)
await self.on_error(interaction, e)
else:
if not interaction.command_failed:
self.client.dispatch('app_command_completion', interaction, command)
finally:
if interaction.command_failed:
logger.debug("Command completed with errors.")
else:
logger.debug("Command completed without errors.")

15
src/meta/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
from .LionBot import LionBot
from .LionCog import LionCog
from .LionContext import LionContext
from .LionTree import LionTree
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
from .config import conf, configEmoji
from .args import args
from .app import appname, shard_talk, appname_from_shard, shard_from_appname
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
from .context import context, ctx_bot
from . import sharding
from . import logger
from . import app

46
src/meta/app.py Normal file
View File

@@ -0,0 +1,46 @@
"""
appname: str
The base identifer for this application.
This identifies which services the app offers.
shardname: str
The specific name of the running application.
Only one process should be connecteded with a given appname.
For the bot apps, usually specifies the shard id and shard number.
"""
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
from . import sharding, conf
from .logger import log_app
from .ipc.client import AppClient
from .args import args
appname = conf.data['appid']
appid = appname # backwards compatibility
def appname_from_shard(shardid):
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
return appname
def shard_from_appname(appname: str):
return int(appname.rsplit('_', maxsplit=1)[-1])
shardname = appname_from_shard(sharding.shard_number)
log_app.set(shardname)
shard_talk = AppClient(
shardname,
appname,
{'host': args.host, 'port': args.port},
{'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
)
@shard_talk.register_route()
async def ping():
return "Pong!"

35
src/meta/args.py Normal file
View File

@@ -0,0 +1,35 @@
import argparse
from constants import CONFIG_FILE
# ------------------------------
# Parsed commandline arguments
# ------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
'--conf',
dest='config',
default=CONFIG_FILE,
help="Path to configuration file."
)
parser.add_argument(
'--shard',
dest='shard',
default=None,
type=int,
help="Shard number to run, if applicable."
)
parser.add_argument(
'--host',
dest='host',
default='127.0.0.1',
help="IP address to run the app listener on."
)
parser.add_argument(
'--port',
dest='port',
default='5001',
help="Port to run the app listener on."
)
args = parser.parse_args()

137
src/meta/config.py Normal file
View File

@@ -0,0 +1,137 @@
from discord import PartialEmoji
import configparser as cfgp
from .args import args
class configEmoji(PartialEmoji):
__slots__ = ('fallback',)
def __init__(self, *args, fallback=None, **kwargs):
super().__init__(*args, **kwargs)
self.fallback = fallback
@classmethod
def from_str(cls, emojistr: str):
"""
Parses emoji strings of one of the following forms
`<a:name:id> or fallback`
`<:name:id> or fallback`
`<a:name:id>`
`<:name:id>`
"""
splits = emojistr.rsplit(' or ', maxsplit=1)
fallback = splits[1] if len(splits) > 1 else None
emojistr = splits[0].strip('<> ')
animated, name, id = emojistr.split(':')
return cls(
name=name,
fallback=PartialEmoji(name=fallback) if fallback is not None else None,
animated=bool(animated),
id=int(id) if id else None
)
class MapDotProxy:
"""
Allows dot access to an underlying Mappable object.
"""
__slots__ = ("_map", "_converter")
def __init__(self, mappable, converter=None):
self._map = mappable
self._converter = converter
def __getattribute__(self, key):
_map = object.__getattribute__(self, '_map')
if key == '_map':
return _map
if key in _map:
_converter = object.__getattribute__(self, '_converter')
if _converter:
return _converter(_map[key])
else:
return _map[key]
else:
return object.__getattribute__(_map, key)
def __getitem__(self, key):
return self._map.__getitem__(key)
class ConfigParser(cfgp.ConfigParser):
"""
Extension of base ConfigParser allowing optional
section option retrieval without defaults.
"""
def options(self, section, no_defaults=False, **kwargs):
if no_defaults:
try:
return list(self._sections[section].keys())
except KeyError:
raise cfgp.NoSectionError(section)
else:
return super().options(section, **kwargs)
class Conf:
def __init__(self, configfile, section_name="DEFAULT"):
self.configfile = configfile
self.config = ConfigParser(
converters={
"intlist": self._getintlist,
"list": self._getlist,
"emoji": configEmoji.from_str,
}
)
self.config.read(configfile)
self.section_name = section_name if section_name in self.config else 'DEFAULT'
self.default = self.config["DEFAULT"]
self.section = MapDotProxy(self.config[self.section_name])
self.bot = self.section
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
more_to_read = self.section.getlist("ALSO_READ", [])
read = set()
while more_to_read:
to_read = more_to_read.pop(0)
read.add(to_read)
self.config.read(to_read)
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
if path not in read and path not in more_to_read]
more_to_read.extend(new_paths)
self.emojis = MapDotProxy(
self.config['EMOJIS'] if 'EMOJIS' in self.config else self.section,
converter=configEmoji.from_str
)
global conf
conf = self
def __getitem__(self, key):
return self.section[key].strip()
def __getattr__(self, section):
return self.config[section.upper()]
def get(self, name, fallback=None):
result = self.section.get(name, fallback)
return result.strip() if result else result
def _getintlist(self, value):
return [int(item.strip()) for item in value.split(',')]
def _getlist(self, value):
return [item.strip() for item in value.split(',')]
def write(self):
with open(self.configfile, 'w') as conffile:
self.config.write(conffile)
conf = Conf(args.config, 'STUDYLION')

20
src/meta/context.py Normal file
View File

@@ -0,0 +1,20 @@
"""
Namespace for various global context variables.
Allows asyncio callbacks to accurately retrieve information about the current state.
"""
from typing import TYPE_CHECKING, Optional
from contextvars import ContextVar
if TYPE_CHECKING:
from .LionBot import LionBot
from .LionContext import LionContext
# Contains the current command context, if applicable
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
# Contains the current LionBot instance
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)

64
src/meta/errors.py Normal file
View File

@@ -0,0 +1,64 @@
from typing import Optional
from string import Template
class SafeCancellation(Exception):
"""
Raised to safely cancel execution of the current operation.
If not caught, is expected to be propagated to the Tree and safely ignored there.
If a `msg` is provided, a context-aware error handler should catch and send the message to the user.
The error handler should then set the `msg` to None, to avoid double handling.
Debugging information should go in `details`, to be logged by a top-level error handler.
"""
default_message = ""
@property
def msg(self):
return self._msg if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
self._msg: Optional[str] = _msg
self.details: str = details if details is not None else self.msg
super().__init__(**kwargs)
class UserInputError(SafeCancellation):
"""
A SafeCancellation induced from unparseable user input.
"""
default_message = "Could not understand your input."
@property
def msg(self):
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
self.info = info
super().__init__(_msg, **kwargs)
class UserCancelled(SafeCancellation):
"""
A SafeCancellation induced from manual user cancellation.
Usually silent.
"""
default_msg = None
class ResponseTimedOut(SafeCancellation):
"""
A SafeCancellation induced from a user interaction time-out.
"""
default_msg = "Session timed out waiting for input."
class HandledException(SafeCancellation):
"""
Sentinel class to indicate to error handlers that this exception has been handled.
Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler.
"""
def __init__(self, exc=None, **kwargs):
self.exc = exc
super().__init__(**kwargs)

2
src/meta/ipc/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .client import AppClient, AppPayload, AppRoute
from .server import AppServer

236
src/meta/ipc/client.py Normal file
View File

@@ -0,0 +1,236 @@
from typing import Optional, TypeAlias, Any
import asyncio
import logging
import pickle
from ..logger import logging_context
logger = logging.getLogger(__name__)
Address: TypeAlias = dict[str, Any]
class AppClient:
routes: dict[str, 'AppRoute'] = {} # route_name -> Callable[Any, Awaitable[Any]]
def __init__(self, appid: str, basename: str, client_address: Address, server_address: Address):
self.appid = appid # String identifier for this ShardTalk client
self.basename = basename # Prefix used to recognise app peers
self.address = client_address
self.server_address = server_address
self.peers = {appid: client_address} # appid -> address
self._listener: Optional[asyncio.Server] = None # Local client server
self._server = None # Connection to the registry server
self._keepalive = None
self.register_route('new_peer')(self.new_peer)
self.register_route('drop_peer')(self.drop_peer)
self.register_route('peer_list')(self.peer_list)
@property
def my_peers(self):
return {peerid: peer for peerid, peer in self.peers.items() if peerid.startswith(self.basename)}
def register_route(self, name=None):
def wrapper(coro):
route = AppRoute(coro, client=self, name=name)
self.routes[route.name] = route
return route
return wrapper
async def server_connection(self):
"""Establish a connection to the registry server"""
try:
reader, writer = await asyncio.open_connection(**self.server_address)
payload = ('connect', (), {'appid': self.appid, 'address': self.address})
writer.write(pickle.dumps(payload))
writer.write(b'\n')
await writer.drain()
data = await reader.readline()
peers = pickle.loads(data)
self.peers = peers
self._server = (reader, writer)
except Exception:
logger.exception(
"Could not connect to registry server. Trying again in 30 seconds.",
extra={'action': 'Connect'}
)
await asyncio.sleep(30)
asyncio.create_task(self.server_connection())
else:
logger.debug(
"Connected to the registry server, launching keepalive.",
extra={'action': 'Connect'}
)
self._keepalive = asyncio.create_task(self._server_keepalive())
async def _server_keepalive(self):
with logging_context(action='Keepalive'):
if self._server is None:
raise ValueError("Cannot keepalive non-existent server!")
reader, write = self._server
try:
await reader.read()
except Exception:
logger.exception("Lost connection to address server. Reconnecting...")
else:
# Connection ended or broke
logger.info("Lost connection to address server. Reconnecting...")
await asyncio.sleep(30)
asyncio.create_task(self.server_connection())
async def new_peer(self, appid, address):
self.peers[appid] = address
async def peer_list(self, peers):
self.peers = peers
async def drop_peer(self, appid):
self.peers.pop(appid, None)
async def close(self):
# Close connection to the server
# TODO
...
async def request(self, appid, payload: 'AppPayload', wait_for_reply=True):
with logging_context(action=f"Req {appid}"):
try:
if appid not in self.peers:
raise ValueError(f"Peer '{appid}' not found.")
logger.debug(f"Sending request to app '{appid}' with payload {payload}")
address = self.peers[appid]
reader, writer = await asyncio.open_connection(**address)
writer.write(payload.encoded())
await writer.drain()
writer.write_eof()
if wait_for_reply:
result = await reader.read()
writer.close()
decoded = payload.route.decode(result)
return decoded
else:
return None
except Exception:
logging.exception(f"Failed to send request to {appid}'")
return None
async def requestall(self, payload, except_self=True, only_my_peers=True):
with logging_context(action="Broadcast"):
peerlist = self.my_peers if only_my_peers else self.peers
results = await asyncio.gather(
*(self.request(appid, payload) for appid in peerlist if (appid != self.appid or not except_self)),
return_exceptions=True
)
return dict(zip(self.peers.keys(), results))
async def handle_request(self, reader, writer):
with logging_context(action="SERV"):
data = await reader.read()
loaded = pickle.loads(data)
route, args, kwargs = loaded
with logging_context(action=f"SERV {route}"):
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes:
try:
await self.routes[route].run((reader, writer), args, kwargs)
except Exception:
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
else:
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
writer.write_eof()
async def connect(self):
"""
Start the local peer server.
Connect to the address server.
"""
with logging_context(stack=['ShardTalk']):
# Start the client server
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
logger.info(f"Serving on {self.address}")
await self.server_connection()
class AppPayload:
__slots__ = ('route', 'args', 'kwargs')
def __init__(self, route, *args, **kwargs):
self.route = route
self.args = args
self.kwargs = kwargs
def __await__(self):
return self.route.execute(*self.args, **self.kwargs).__await__()
def encoded(self):
return pickle.dumps((self.route.name, self.args, self.kwargs))
async def send(self, appid, **kwargs):
return await self.route._client.request(appid, self, **kwargs)
async def broadcast(self, **kwargs):
return await self.route._client.requestall(self, **kwargs)
class AppRoute:
__slots__ = ('func', 'name', '_client')
def __init__(self, func, client=None, name=None):
self.func = func
self.name = name or func.__name__
self._client = client
def __call__(self, *args, **kwargs):
return AppPayload(self, *args, **kwargs)
def encode(self, output):
return pickle.dumps(output)
def decode(self, encoded):
# TODO: Handle exceptions here somehow
if len(encoded) > 0:
return pickle.loads(encoded)
else:
return ''
def encoder(self, func):
self.encode = func
def decoder(self, func):
self.decode = func
async def execute(self, *args, **kwargs):
"""
Execute the underlying function, with the given arguments.
"""
return await self.func(*args, **kwargs)
async def run(self, connection, args, kwargs):
"""
Run the route, with the given arguments, using the given connection.
"""
# TODO: ContextVar here for logging? Or in handle_request?
# Get encoded result
# TODO: handle exceptions in the execution process
try:
result = await self.execute(*args, **kwargs)
payload = self.encode(result)
except Exception:
logger.exception(f"Exception occured running route '{self.name}' with args: {args} and kwargs: {kwargs}")
payload = b''
_, writer = connection
writer.write(payload)
await writer.drain()
writer.close()

175
src/meta/ipc/server.py Normal file
View File

@@ -0,0 +1,175 @@
import asyncio
import pickle
import logging
import string
import random
from ..logger import log_context, log_app, logging_context
logger = logging.getLogger(__name__)
uuid_alphabet = string.ascii_lowercase + string.digits
def short_uuid():
return ''.join(random.choices(uuid_alphabet, k=10))
class AppServer:
routes = {} # route name -> bound method
def __init__(self):
self.clients = {} # AppID -> (info, connection)
self.route('ping')(self.route_ping)
self.route('whereis')(self.route_whereis)
self.route('peers')(self.route_peers)
self.route('connect')(self.client_connection)
@classmethod
def route(cls, route_name):
"""
Decorator to add a route to the server.
"""
def wrapper(coro):
cls.routes[route_name] = coro
return coro
return wrapper
async def route_ping(self, connection):
"""
Pong.
"""
reader, writer = connection
writer.write(b"Pong")
writer.write_eof()
async def route_whereis(self, connection, appid):
"""
Return an address for the given client appid.
Returns None if the client does not have a connection.
"""
reader, writer = connection
if appid in self.clients:
writer.write(pickle.dumps(self.clients[appid][0]))
else:
writer.write(b'')
writer.write_eof()
async def route_peers(self, connection):
"""
Send back a map of current peers.
"""
reader, writer = connection
peers = self.peer_list()
payload = pickle.dumps(('peer_list', (peers,)))
writer.write(payload)
writer.write_eof()
async def client_connection(self, connection, appid, address):
"""
Register and hold a new client connection.
"""
with logging_context(action=f"CONN {appid}"):
reader, writer = connection
# Add the new client
self.clients[appid] = (address, connection)
# Send the new client a client list
peers = self.peer_list()
writer.write(pickle.dumps(peers))
writer.write(b'\n')
await writer.drain()
# Announce the new client to everyone
await self.broadcast('new_peer', (), {'appid': appid, 'address': address})
# Keep the connection open until socket closed or EOF (indicating client death)
try:
await reader.read()
finally:
# Connection ended or it broke
logger.info(f"Lost client '{appid}'")
await self.deregister_client(appid)
async def handle_connection(self, reader, writer):
data = await reader.readline()
route, args, kwargs = pickle.loads(data)
rqid = short_uuid()
with logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}"):
logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}")
if route in self.routes:
# Execute route
try:
await self.routes[route]((reader, writer), *args, **kwargs)
except Exception:
logger.exception(f"AppServer recieved exception during route '{route}'")
else:
logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.")
def peer_list(self):
return {appid: address for appid, (address, _) in self.clients.items()}
async def deregister_client(self, appid):
self.clients.pop(appid, None)
await self.broadcast('drop_peer', (), {'appid': appid})
async def broadcast(self, route, args, kwargs):
with logging_context(action="broadcast"):
logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.")
payload = pickle.dumps((route, args, kwargs))
if self.clients:
await asyncio.gather(
*(self._send(appid, payload) for appid in self.clients),
return_exceptions=True
)
async def message_client(self, appid, route, args, kwargs):
"""
Send a message to client `appid` along `route` with given arguments.
"""
with logging_context(action=f"MSG {appid}"):
logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.")
if appid not in self.clients:
raise ValueError(f"Client '{appid}' is not connected.")
payload = pickle.dumps((route, args, kwargs))
return await self._send(appid, payload)
async def _send(self, appid, payload):
"""
Send the encoded `payload` to the client `appid`.
"""
address, _ = self.clients[appid]
try:
reader, writer = await asyncio.open_connection(**address)
writer.write(payload)
writer.write_eof()
await writer.drain()
writer.close()
except Exception as ex:
# TODO: Close client if we can't connect?
logger.exception(f"Failed to send message to '{appid}'")
raise ex
async def start(self, address):
log_app.set("APPSERVER")
with logging_context(stack=["SERV"]):
server = await asyncio.start_server(self.handle_connection, **address)
logger.info(f"Serving on {address}")
async with server:
await server.serve_forever()
async def start_server():
address = {'host': '127.0.0.1', 'port': '5000'}
server = AppServer()
await server.start(address)
if __name__ == '__main__':
asyncio.run(start_server())

324
src/meta/logger.py Normal file
View File

@@ -0,0 +1,324 @@
import sys
import logging
import asyncio
from typing import List
from logging.handlers import QueueListener, QueueHandler
from queue import SimpleQueue
from contextlib import contextmanager
from io import StringIO
from functools import wraps
from contextvars import ContextVar
from discord import Webhook, File
import aiohttp
from .config import conf
from . import sharding
from .context import context
from utils.lib import utc_now
log_logger = logging.getLogger(__name__)
log_logger.propagate = False
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action_stack: ContextVar[List[str]] = ContextVar('logging_action_stack', default=[])
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
@contextmanager
def logging_context(context=None, action=None, stack=None):
if context is not None:
context_t = log_context.set(context)
if action is not None:
astack = log_action_stack.get()
log_action_stack.set(astack + [action])
if stack is not None:
actions_t = log_action_stack.set(stack)
try:
yield
finally:
if context is not None:
log_context.reset(context_t)
if stack is not None:
log_action_stack.reset(actions_t)
if action is not None:
log_action_stack.set(astack)
def log_wrap(**kwargs):
def decorator(func):
@wraps(func)
async def wrapped(*w_args, **w_kwargs):
with logging_context(**kwargs):
return await func(*w_args, **w_kwargs)
return wrapped
return decorator
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
"]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
def colour_escape(fmt: str) -> str:
cmap = {
'%(black)': COLOR_SEQ % BLACK,
'%(red)': COLOR_SEQ % RED,
'%(green)': COLOR_SEQ % GREEN,
'%(yellow)': COLOR_SEQ % YELLOW,
'%(blue)': COLOR_SEQ % BLUE,
'%(magenta)': COLOR_SEQ % MAGENTA,
'%(cyan)': COLOR_SEQ % CYAN,
'%(white)': COLOR_SEQ % WHITE,
'%(reset)': RESET_SEQ,
'%(bold)': BOLD_SEQ,
}
for key, value in cmap.items():
fmt = fmt.replace(key, value)
return fmt
log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' +
'[%(cyan)%(app)-15s%(reset)]' +
'[%(cyan)%(context)-24s%(reset)]' +
'[%(cyan)%(actionstr)-22s%(reset)]' +
' %(bold)%(cyan)%(name)s:%(reset)' +
' %(white)%(message)s%(ctxstr)s%(reset)')
log_format = colour_escape(log_format)
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
fmt=log_format,
# datefmt='%Y-%m-%d %H:%M:%S'
)
logger.setLevel(logging.NOTSET)
class LessThanFilter(logging.Filter):
def __init__(self, exclusive_maximum, name=""):
super(LessThanFilter, self).__init__(name)
self.max_level = exclusive_maximum
def filter(self, record):
# non-zero return means we log this message
return 1 if record.levelno < self.max_level else 0
class ThreadFilter(logging.Filter):
def __init__(self, thread_name):
super().__init__("")
self.thread = thread_name
def filter(self, record):
# non-zero return means we log this message
return 1 if record.threadName == self.thread else 0
class ContextInjection(logging.Filter):
def filter(self, record):
# These guards are to allow override through _extra
# And to ensure the injection is idempotent
if not hasattr(record, 'context'):
record.context = log_context.get()
if not hasattr(record, 'actionstr'):
action_stack = log_action_stack.get()
if hasattr(record, 'action'):
action_stack = (*action_stack, record.action)
if action_stack:
record.actionstr = ''.join(action_stack)
else:
record.actionstr = "Unknown Action"
if not hasattr(record, 'app'):
record.app = log_app.get()
if not hasattr(record, 'ctx'):
if ctx := context.get():
record.ctx = repr(ctx)
else:
record.ctx = None
if getattr(record, 'with_ctx', False) and record.ctx:
record.ctxstr = '\n' + record.ctx
else:
record.ctxstr = ""
return True
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out)
log_logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err)
log_logger.addHandler(logging_handler_err)
class LocalQueueHandler(QueueHandler):
def _emit(self, record: logging.LogRecord) -> None:
# Removed the call to self.prepare(), handle task cancellation
try:
self.enqueue(record)
except asyncio.CancelledError:
raise
except Exception:
self.handleError(record)
class WebHookHandler(logging.StreamHandler):
def __init__(self, webhook_url, batch=False, loop=None):
super().__init__()
self.webhook_url = webhook_url
self.batched = ""
self.batch = batch
self.loop = loop
self.batch_delay = 10
self.batch_task = None
self.last_batched = None
self.waiting = []
def get_loop(self):
if self.loop is None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
return self.loop
def emit(self, record):
self.get_loop().call_soon_threadsafe(self._post, record)
def _post(self, record):
asyncio.create_task(self.post(record))
async def post(self, record):
log_context.set("Webhook Logger")
log_action_stack.set(["Logging"])
log_app.set(record.app)
try:
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
context = f"\n# Context: {record.ctx}" if record.ctx else ""
message = f"{header}\n{record.msg}{context}"
if len(message) > 1900:
as_file = True
else:
as_file = False
message = "```md\n{}\n```".format(message)
# Post the log message(s)
if self.batch:
if len(message) > 1500:
await self._send_batched_now()
await self._send(message, as_file=as_file)
else:
self.batched += message
if len(self.batched) + len(message) > 1500:
await self._send_batched_now()
else:
asyncio.create_task(self._schedule_batched())
else:
await self._send(message, as_file=as_file)
except Exception as ex:
print(ex)
async def _schedule_batched(self):
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
# noop, don't reschedule if it is already scheduled
return
try:
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
await self.batch_task
await self._send_batched()
except asyncio.CancelledError:
return
except Exception as ex:
print(ex)
async def _send_batched_now(self):
if self.batch_task is not None and not self.batch_task.done():
self.batch_task.cancel()
self.last_batched = None
await self._send_batched()
async def _send_batched(self):
if self.batched:
batched = self.batched
self.batched = ""
await self._send(batched)
async def _send(self, message, as_file=False):
async with aiohttp.ClientSession() as session:
webhook = Webhook.from_url(self.webhook_url, session=session)
if as_file or len(message) > 2000:
with StringIO(message) as fp:
fp.seek(0)
await webhook.send(
file=File(fp, filename="logs.md"),
username=log_app.get()
)
else:
await webhook.send(message, username=log_app.get())
handlers = []
if webhook := conf.logging['general_log']:
handler = WebHookHandler(webhook, batch=True)
handlers.append(handler)
if webhook := conf.logging['error_log']:
handler = WebHookHandler(webhook, batch=False)
handler.setLevel(logging.ERROR)
handlers.append(handler)
if webhook := conf.logging['critical_log']:
handler = WebHookHandler(webhook, batch=False)
handler.setLevel(logging.CRITICAL)
handlers.append(handler)
if handlers:
# First create a separate loop to run the handlers on
import threading
def run_loop(loop):
asyncio.set_event_loop(loop)
try:
loop.run_forever()
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
loop = asyncio.new_event_loop()
loop_thread = threading.Thread(target=lambda: run_loop(loop))
loop_thread.daemon = True
loop_thread.start()
for handler in handlers:
handler.loop = loop
queue: SimpleQueue[logging.LogRecord] = SimpleQueue()
qhandler = QueueHandler(queue)
qhandler.setLevel(logging.INFO)
qhandler.addFilter(ContextInjection())
# qhandler.addFilter(ThreadFilter('MainThread'))
logger.addHandler(qhandler)
listener = QueueListener(
queue, *handlers, respect_handler_level=True
)
listener.start()

View File

@@ -0,0 +1,34 @@
from discord import Intents
from cmdClient.cmdClient import cmdClient
from . import patches
from .interactions import InteractionType
from .config import conf
from .sharding import shard_number, shard_count
from LionContext import LionContext
# Initialise client
owners = [int(owner) for owner in conf.bot.getlist('owners')]
intents = Intents.all()
intents.presences = False
client = cmdClient(
prefix=conf.bot['prefix'],
owners=owners,
intents=intents,
shard_id=shard_number,
shard_count=shard_count,
baseContext=LionContext
)
client.conf = conf
# TODO: Could include client id here, or app id, to avoid multiple handling.
NOOP_ID = 'NOOP'
@client.add_after_event('interaction_create')
async def handle_noop_interaction(client, interaction):
if interaction.interaction_type in (InteractionType.MESSAGE_COMPONENT, InteractionType.MODAL_SUBMIT):
if interaction.custom_id == NOOP_ID:
interaction.ack()

View File

@@ -0,0 +1,110 @@
import sys
import logging
import asyncio
from discord import AllowedMentions
from cmdClient.logger import cmd_log_handler
from utils.lib import mail, split_text
from .client import client
from .config import conf
from . import sharding
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
fmt=('[{asctime}][{levelname:^8}]' +
'[SHARD {}]'.format(sharding.shard_number) +
' {message}'),
datefmt='%d/%m | %H:%M:%S',
style='{'
)
# term_handler = logging.StreamHandler(sys.stdout)
# term_handler.setFormatter(log_fmt)
# logger.addHandler(term_handler)
# logger.setLevel(logging.INFO)
class LessThanFilter(logging.Filter):
def __init__(self, exclusive_maximum, name=""):
super(LessThanFilter, self).__init__(name)
self.max_level = exclusive_maximum
def filter(self, record):
# non-zero return means we log this message
return 1 if record.levelno < self.max_level else 0
logger.setLevel(logging.NOTSET)
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logger.addHandler(logging_handler_err)
# Define the context log format and attach it to the command logger as well
@cmd_log_handler
def log(message, context="GLOBAL", level=logging.INFO, post=True):
# Add prefixes to lines for better parsing capability
lines = message.splitlines()
if len(lines) > 1:
lines = [
'' * (i == 0) + '' * (0 < i < len(lines) - 1) + '' * (i == len(lines) - 1) + line
for i, line in enumerate(lines)
]
else:
lines = ['' + message]
for line in lines:
logger.log(level, '\b[{}] {}'.format(
str(context).center(22, '='),
line
))
# Fire and forget to the channel logger, if it is set up
if post and client.is_ready():
asyncio.ensure_future(live_log(message, context, level))
# Live logger that posts to the logging channels
async def live_log(message, context, level):
if level >= logging.INFO:
if level >= logging.WARNING:
log_chid = conf.bot.getint('error_channel') or conf.bot.getint('log_channel')
else:
log_chid = conf.bot.getint('log_channel')
# Generate the log messages
if sharding.sharded:
header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
else:
header = f"[{logging.getLevelName(level)}][{context}]"
if len(message) > 1900:
blocks = split_text(message, blocksize=1900, code=False)
else:
blocks = [message]
if len(blocks) > 1:
blocks = [
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
]
else:
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
# Post the log messages
if log_chid:
[await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]
# Attach logger to client, for convenience
client.log = log

35
src/meta/sharding.py Normal file
View File

@@ -0,0 +1,35 @@
from .args import args
from .config import conf
from psycopg import sql
from data.conditions import Condition, Joiner
shard_number = args.shard or 0
shard_count = conf.bot.getint('shard_count', 1)
sharded = (shard_count > 0)
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
"""
Condition constructor for filtering by shard id.
Example Usage
-------------
Query.where(_shard_condition('guildid', 10, 1))
"""
return Condition(
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
guildid=sql.Identifier(guild_column),
shard_count=sql.Literal(shard_count)
),
Joiner.EQUALS,
sql.Placeholder(),
(shard_id,)
)
# Pre-built Condition for filtering by current shard.
THIS_SHARD = SHARDID(shard_number)

16
src/modules/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
this_package = 'modules'
active = [
'.sysadmin',
'.config',
'.economy',
'.reminders',
'.shop',
'.tasklist',
'.test',
]
async def setup(bot):
for ext in active:
await bot.load_extension(ext, package=this_package)

View File

@@ -0,0 +1,10 @@
import logging
from babel.translator import LocalBabel
logger = logging.getLogger(__name__)
babel = LocalBabel('config')
async def setup(bot):
from .cog import ConfigCog
await bot.add_cog(ConfigCog(bot))

30
src/modules/config/cog.py Normal file
View File

@@ -0,0 +1,30 @@
import discord
from discord import app_commands as appcmds
from discord.ext import commands as cmds
from meta import LionBot, LionContext, LionCog
from . import babel
_p = babel._p
class ConfigCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
async def cog_load(self):
...
async def cog_unload(self):
...
@cmds.hybrid_group(
name=_p('group:configure', "configure"),
)
@appcmds.guild_only
async def configure_group(self, ctx: LionContext):
"""
Bare command group, has no function.
"""
return

View File

@@ -0,0 +1,5 @@
from .cog import Economy
async def setup(bot):
await bot.add_cog(Economy(bot))

969
src/modules/economy/cog.py Normal file
View File

@@ -0,0 +1,969 @@
from typing import Optional, Union
from enum import Enum
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from psycopg import sql
from data import Registry, RowModel, RegisterEnum, ORDER, JOINTYPE, RawExpr
from data.columns import Integer, Bool, String, Column, Timestamp
from meta import LionCog, LionBot, LionContext
from meta.errors import ResponseTimedOut
from babel import LocalBabel
from core.data import CoreData
from utils.ui import LeoUI, LeoModal, Confirm, Pager
from utils.lib import error_embed, MessageArgs, utc_now
babel = LocalBabel('economy')
_, _p, _np = babel._, babel._p, babel._np
MAX_COINS = 2**16
class TransactionType(Enum):
"""
Schema
------
CREATE TYPE CoinTransactionType AS ENUM(
'REFUND',
'TRANSFER',
'SHOP_PURCHASE',
'STUDY_SESSION',
'ADMIN',
'TASKS'
);
"""
REFUND = 'REFUND',
TRANSFER = 'TRANSFER',
PURCHASE = 'SHOP_PURCHASE',
SESSION = 'STUDY_SESSION',
ADMIN = 'ADMIN',
TASKS = 'TASKS',
class AdminActionTarget(Enum):
"""
Schema
------
CREATE TYPE EconAdminTarget AS ENUM(
'ROLE',
'USER',
'GUILD'
);
"""
ROLE = 'ROLE',
USER = 'USER',
GUILD = 'GUILD',
class AdminActionType(Enum):
"""
Schema
------
CREATE TYPE EconAdminAction AS ENUM(
'SET',
'ADD'
);
"""
SET = 'SET',
ADD = 'ADD',
class EconomyData(Registry, name='economy'):
_TransactionType = RegisterEnum(TransactionType, 'CoinTransactionType')
_AdminActionTarget = RegisterEnum(AdminActionTarget, 'EconAdminTarget')
_AdminActionType = RegisterEnum(AdminActionType, 'EconAdminAction')
class Transaction(RowModel):
"""
Schema
------
CREATE TABLE coin_transactions(
transactionid SERIAL PRIMARY KEY,
transactiontype CoinTransactionType NOT NULL,
guildid BIGINT NOT NULL REFERENCES guild_config (guildid) ON DELETE CASCADE,
actorid BIGINT NOT NULL,
amount INTEGER NOT NULL,
bonus INTEGER NOT NULL,
from_account BIGINT,
to_account BIGINT,
refunds INTEGER REFERENCES coin_transactions (transactionid) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT (now() at time zone 'utc')
);
CREATE INDEX coin_transaction_guilds ON coin_transactions (guildid);
"""
_tablename_ = 'coin_transactions'
transactionid = Integer(primary=True)
transactiontype: Column[TransactionType] = Column()
guildid = Integer()
actorid = Integer()
amount = Integer()
bonus = Integer()
from_account = Integer()
to_account = Integer()
refunds = Integer()
created_at = Timestamp()
@classmethod
async def execute_transaction(
cls,
transaction_type: TransactionType,
guildid: int, actorid: int,
from_account: int, to_account: int, amount: int, bonus: int = 0,
refunds: int = None
):
transaction = await cls.create(
transactiontype=transaction_type,
guildid=guildid, actorid=actorid, amount=amount, bonus=bonus,
from_account=from_account, to_account=to_account,
refunds=refunds
)
if from_account is not None:
await CoreData.Member.table.update_where(
guildid=guildid, userid=from_account
).set(coins=(CoreData.Member.coins - (amount + bonus)))
if to_account is not None:
await CoreData.Member.table.update_where(
guildid=guildid, userid=to_account
).set(coins=(CoreData.Member.coins + (amount + bonus)))
return transaction
class ShopTransaction(RowModel):
"""
Schema
------
CREATE TABLE coin_transactions_shop(
transactionid INTEGER PRIMARY KEY REFERENCES coin_transactions (transactionid) ON DELETE CASCADE,
itemid INTEGER NOT NULL REFERENCES shop_items (itemid) ON DELETE CASCADE
);
"""
_tablename_ = 'coin_transactions_shop'
transactionid = Integer(primary=True)
itemid = Integer()
@classmethod
async def purchase_transaction(
cls,
guildid: int, actorid: int,
userid: int, itemid: int, amount: int
):
conn = await cls._connector.get_connection()
async with conn.transaction():
row = await EconomyData.Transaction.execute_transaction(
TransactionType.PURCHASE,
guildid=guildid, actorid=actorid, from_account=userid, to_account=None,
amount=amount
)
return await cls.create(transactionid=row.transactionid, itemid=itemid)
class TaskTransaction(RowModel):
"""
Schema
------
CREATE TABLE coin_transactions_tasks(
transactionid INTEGER PRIMARY KEY REFERENCES coin_transactions (transactionid) ON DELETE CASCADE,
count INTEGER NOT NULL
);
"""
_tablename_ = 'coin_transactions_tasks'
transactionid = Integer(primary=True)
count = Integer()
@classmethod
async def count_recent_for(cls, userid, guildid, interval='24h'):
"""
Retrieve the number of tasks rewarded in the last `interval`.
"""
T = EconomyData.Transaction
query = cls.table.select_where().with_no_adapter()
query.join(T, using=(T.transactionid.name, ), join_type=JOINTYPE.LEFT)
query.select(recent=sql.SQL("SUM({})").format(cls.count.expr))
query.where(
T.to_account == userid,
T.guildid == guildid,
T.created_at > RawExpr(sql.SQL("timezone('utc', NOW()) - INTERVAL {}").format(interval), ()),
)
result = await query
return result[0]['recent'] or 0
@classmethod
async def reward_completed(cls, userid, guildid, count, amount):
"""
Reward the specified member `amount` coins for completing `count` tasks.
"""
# TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog?
conn = await cls._connector.get_connection()
async with conn.transaction():
row = await EconomyData.Transaction.execute_transaction(
TransactionType.TASKS,
guildid=guildid, actorid=userid, from_account=None, to_account=userid,
amount=amount
)
return await cls.create(transactionid=row.transactionid, count=count)
class SessionTransaction(RowModel):
"""
Schema
------
CREATE TABLE coin_transactions_sessions(
transactionid INTEGER PRIMARY KEY REFERENCES coin_transactions (transactionid) ON DELETE CASCADE,
sessionid INTEGER NOT NULL REFERENCES session_history (sessionid) ON DELETE CASCADE
);
"""
_tablename_ = 'coin_transactions_sessions'
transactionid = Integer(primary=True)
sessionid = Integer()
class AdminActions(RowModel):
"""
Schema
------
CREATE TABLE economy_admin_actions(
actionid SERIAL PRIMARY KEY,
target_type EconAdminTarget NOT NULL,
action_type EconAdminAction NOT NULL,
targetid INTEGER NOT NULL,
amount INTEGER NOT NULL
);
"""
_tablename_ = 'economy_admin_actions'
actionid = Integer(primary=True)
target_type: Column[AdminActionTarget] = Column()
action_type: Column[AdminActionType] = Column()
targetid = Integer()
amount = Integer()
class AdminTransactions(RowModel):
"""
Schema
------
CREATE TABLE coin_transactions_admin_actions(
actionid INTEGER NOT NULL REFERENCES economy_admin_actions (actionid),
transactionid INTEGER NOT NULL REFERENCES coin_transactions (transactionid),
PRIMARY KEY (actionid, transactionid)
);
CREATE INDEX coin_transactions_admin_actions_transactionid ON coin_transactions_admin_actions (transactionid);
"""
_tablename_ = 'coin_transactions_admin_actions'
actionid = Integer(primary=True)
transactionid = Integer(primary=True)
class Economy(LionCog):
"""
Commands
--------
/economy balances [target:<mentionable>] [add:<int>] [set:<int>].
With no arguments, show a summary of current balances in the server.
With a target user or role, show their balance, and possibly their most recent transactions.
With a target user or role, and add or set, modify their balance. Confirm if more than 1 user is affected.
With no target user or role, apply to everyone in the guild. Confirm if more than 1 user affected.
/economy reset [target:<mentionable>]
Reset the economy system for the given target, or everyone in the guild.
Acts as an alias to `/economy balances target:target set:0
/economy history [target:<mentionable>]
Display a paged audit trail with the history of the selected member,
all the users in the selected role, or all users.
/sendcoins <user:<user>> [note:<str>]
Send coins to the specified user, with an optional note.
"""
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(EconomyData())
async def cog_load(self):
await self.data.init()
# ----- Economy group commands -----
@cmds.hybrid_group(name=_p('cmd:economy', "economy"))
@cmds.guild_only()
async def economy_group(self, ctx: LionContext):
pass
@economy_group.command(
name=_p('cmd:economy_balance', "balance"),
description=_p(
'cmd:economy_balance|desc',
"Display and modify LionCoin balance for members or roles."
)
)
@appcmds.rename(
target=_p('cmd:economy_balance|param:target', "target"),
add=_p('cmd:economy_balance|param:add', "add"),
set_to=_p('cmd:economy_balance|param:set', "set")
)
@appcmds.describe(
target=_p(
'cmd:economy_balance|param:target|desc',
"Target user or role to view or update. Use @everyone to update the entire guild."
),
add=_p(
'cmd:economy_balance|param:add|desc',
"Number of LionCoins to add to the target member's balance. May be negative to remove."
),
set_to=_p(
'cmd:economy_balance|param:set|set',
"New balance to set the target's balance to."
)
)
async def economy_balance_cmd(
self,
ctx: LionContext,
target: discord.User | discord.Member | discord.Role,
set_to: Optional[appcmds.Range[int, 0, MAX_COINS]] = None,
add: Optional[int] = None
):
t = self.bot.translator.t
cemoji = self.bot.config.emojis.getemoji('coin')
targets: list[Union[discord.User, discord.Member]]
if not ctx.guild:
# Added for the typechecker
# This is impossible from the guild_only ward
return
if not self.bot.core:
return
if not ctx.interaction:
return
if isinstance(target, discord.Role):
targets = [mem for mem in target.members if not mem.bot]
role = target
else:
targets = [target]
role = None
if role and not targets:
# Guard against provided target role having no members
# Possible chunking failed for this guild, want to explicitly inform.
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:economy_balance|error:no_target',
"There are no valid members in {role.mention}! It has a total of `0` LC."
)).format(role=target)
),
ephemeral=True
)
elif not role and target.bot:
# Guard against reading or modifying a bot account
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:economy_balance|error:target_is_bot',
"Bots cannot have coin balances!"
))
),
ephemeral=True
)
elif set_to is not None and add is not None:
# Requested operation doesn't make sense
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:economy_balance|error:args',
"You cannot simultaneously `set` and `add` member balances!"
))
),
ephemeral=True
)
elif set_to is not None or add is not None:
# Setting route
# First ensure all the targets we will be updating already have rows
# As this is one of the few operations that acts on members not already registered,
# We may need to do a mass row create operation.
targetids = set(target.id for target in targets)
if len(targets) > 1:
conn = await ctx.bot.db.get_connection()
async with conn.transaction():
# First fetch the members which currently exist
query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id)
query.select('userid').with_no_adapter()
if 2 * len(targets) < len(ctx.guild.members):
# More efficient to fetch the targets explicitly
query.where(userid=list(targetids))
existent_rows = await query
existentids = set(r['userid'] for r in existent_rows)
# Then check if any new userids need adding, and if so create them
new_ids = targetids.difference(existentids)
if new_ids:
# We use ON CONFLICT IGNORE here in case the users already exist.
await self.bot.core.data.User.table.insert_many(
('userid',),
*((id,) for id in new_ids)
).on_conflict(ignore=True)
# TODO: Replace 0 here with the starting_coin value
await self.bot.core.data.Member.table.insert_many(
('guildid', 'userid', 'coins'),
*((ctx.guild.id, id, 0) for id in new_ids)
).on_conflict(ignore=True)
else:
# With only one target, we can take a simpler path, and make better use of local caches.
await self.bot.core.lions.fetch(ctx.guild.id, target.id)
# Now we are certain these members have a database row
# Perform the appropriate action
if role:
affected = t(_np(
'cmd:economy_balance|embed:success|affected',
"One user was affected.",
"**{count}** users were affected.",
len(targets)
)).format(count=len(targets))
conf_affected = t(_np(
'cmd:economy_balance|confirm|affected',
"One user will be affected.",
"**{count}** users will be affected.",
len(targets)
)).format(count=len(targets))
confirm = Confirm(conf_affected)
confirm.confirm_button = t(_p(
'cmd:economy_balance|confirm|button:confirm',
"Yes, adjust balances"
))
confirm.confirm_button = t(_p(
'cmd:economy_balance|confirm|button:cancel',
"No, cancel"
))
if set_to is not None:
if role:
if role.is_default():
description = t(_p(
'cmd:economy_balance|embed:success_set|desc',
"All members of **{guild_name}** have had their "
"balance set to {coin_emoji}**{amount}**."
)).format(
guild_name=ctx.guild.name,
coin_emoji=cemoji,
amount=set_to
) + '\n' + affected
conf_description = t(_p(
'cmd:economy_balance|confirm_set|desc',
"Are you sure you want to set everyone's balance to {coin_emoji}**{amount}**?"
)).format(
coin_emoji=cemoji,
amount=set_to
) + '\n' + conf_affected
else:
description = t(_p(
'cmd:economy_balance|embed:success_set|desc',
"All members of {role_mention} have had their "
"balance set to {coin_emoji}**{amount}**."
)).format(
role_mention=role.mention,
coin_emoji=cemoji,
amount=set_to
) + '\n' + affected
conf_description = t(_p(
'cmd:economy_balance|confirm_set|desc',
"Are you sure you want to set the balance of everyone with {role_mention} "
"to {coin_emoji}**{amount}**?"
)).format(
role_mention=role.mention,
coin_emoji=cemoji,
amount=set_to
) + '\n' + conf_affected
confirm.embed.description = conf_description
try:
result = await confirm.ask(ctx.interaction, ephemeral=True)
except ResponseTimedOut:
return
if not result:
return
else:
description = t(_p(
'cmd:economy_balance|embed:success_set|desc',
"{user_mention} now has a balance of {coin_emoji}**{amount}**."
)).format(
user_mention=target.mention,
coin_emoji=cemoji,
amount=set_to
)
await self.bot.core.data.Member.table.update_where(
guildid=ctx.guild.id, userid=list(targetids)
).set(
coins=set_to
)
else:
if role:
if role.is_default():
description = t(_p(
'cmd:economy_balance|embed:success_add|desc',
"All members of **{guild_name}** have been given "
"{coin_emoji}**{amount}**."
)).format(
guild_name=ctx.guild.name,
coin_emoji=cemoji,
amount=add
) + '\n' + affected
conf_description = t(_p(
'cmd:economy_balance|confirm_add|desc',
"Are you sure you want to add **{amount}** to everyone's balance?"
)).format(
coin_emoji=cemoji,
amount=add
) + '\n' + conf_affected
else:
description = t(_p(
'cmd:economy_balance|embed:success_add|desc',
"All members of {role_mention} have been given "
"{coin_emoji}**{amount}**."
)).format(
role_mention=role.mention,
coin_emoji=cemoji,
amount=add
) + '\n' + affected
conf_description = t(_p(
'cmd:economy_balance|confirm_add|desc',
"Are you sure you want to add {coin_emoji}**{amount}** to everyone in {role_mention}?"
)).format(
coin_emoji=cemoji,
amount=add,
role_mention=role.mention
) + '\n' + conf_affected
confirm.embed.description = conf_description
try:
result = await confirm.ask(ctx.interaction, ephemeral=True)
except ResponseTimedOut:
return
if not result:
return
results = await self.bot.core.data.Member.table.update_where(
guildid=ctx.guild.id, userid=list(targetids)
).set(
coins=(self.bot.core.data.Member.coins + add)
)
# Single member case occurs afterwards so we can pick up the results
if not role:
description = t(_p(
'cmd:economy_balance|embed:success_add|desc',
"{user_mention} was given {coin_emoji}**{amount}**, and "
"now has a balance of {coin_emoji}**{new_amount}**."
)).format(
user_mention=target.mention,
coin_emoji=cemoji,
amount=add,
new_amount=results[0]['coins']
)
title = t(_np(
'cmd:economy_balance|embed:success|title',
"Account successfully updated.",
"Accounts successfully updated.",
len(targets)
))
await ctx.reply(
embed=discord.Embed(
colour=discord.Colour.brand_green(),
description=description,
title=title,
)
)
else:
# Viewing route
MemModel = self.bot.core.data.Member
if role:
query = MemModel.fetch_where(
(MemModel.guildid == role.guild.id) & (MemModel.coins != 0)
)
query.order_by('coins', ORDER.DESC)
if not role.is_default():
# Everyone role is handled differently for data efficiency
ids = [target.id for target in targets]
query = query.where(userid=ids)
rows = await query
name = t(_p(
'cmd:economy_balance|embed:role_lb|author',
"Balance sheet for {name}"
)).format(name=role.name if not role.is_default() else role.guild.name)
if rows:
if role.is_default():
header = t(_p(
'cmd:economy_balance|embed:role_lb|header',
"This server has a total balance of {coin_emoji}**{total}**."
)).format(
coin_emoji=cemoji,
total=sum(row.coins for row in rows)
)
else:
header = t(_p(
'cmd:economy_balance|embed:role_lb|header',
"{role_mention} has `{count}` members with non-zero balance, "
"with a total balance of {coin_emoji}**{total}**."
)).format(
count=len(targets),
role_mention=role.mention,
total=sum(row.coins for row in rows),
coin_emoji=cemoji
)
# Build the leaderboard:
lb_format = t(_p(
'cmd:economy_balance|embed:role_lb|row_format',
"`[{pos:>{numwidth}}]` | `{coins:>{coinwidth}} LC` | {mention}"
))
blocklen = 20
blocks = [rows[i:i+blocklen] for i in range(0, len(rows), blocklen)]
paged = len(blocks) > 1
pages = []
for i, block in enumerate(blocks):
lines = []
numwidth = len(str(i + len(block)))
coinwidth = len(str(max(row.coins for row in rows)))
for j, row in enumerate(block, start=i):
lines.append(
lb_format.format(
pos=j, numwidth=numwidth,
coins=row.coins, coinwidth=coinwidth,
mention=f"<@{row.userid}>"
)
)
lb_block = '\n'.join(lines)
embed = discord.Embed(
description=f"{header}\n{lb_block}"
)
embed.set_author(name=name)
if paged:
embed.set_footer(
text=t(_p(
'cmd:economy_balance|embed:role_lb|footer',
"Page {page}/{total}"
)).format(page=i+1, total=len(blocks))
)
pages.append(MessageArgs(embed=embed))
pager = Pager(pages, show_cancel=True)
await pager.run(ctx.interaction)
else:
if role.is_default():
header = t(_p(
'cmd:economy_balance|embed:role_lb|header',
"This server has a total balance of {coin_emoji}**0**."
)).format(
coin_emoji=cemoji,
)
else:
header = t(_p(
'cmd:economy_balance|embed:role_lb|header',
"The role {role_mention} has a total balance of {coin_emoji}**0**."
)).format(
role_mention=role.mention,
coin_emoji=cemoji
)
embed = discord.Embed(
colour=discord.Colour.orange(),
description=header
)
embed.set_author(name=name)
await ctx.reply(embed=embed)
else:
# If we have a single target, show their current balance, with a short transaction history.
user = targets[0]
row = await self.bot.core.data.Member.fetch(ctx.guild.id, user.id)
embed = discord.Embed(
colour=discord.Colour.orange(),
description=t(_p(
'cmd:economy_balance|embed:single|desc',
"{mention} currently owns {coin_emoji} {coins}."
)).format(
mention=user.mention,
coin_emoji=self.bot.config.emojis.getemoji('coin'),
coins=row.coins
)
).set_author(
icon_url=user.avatar,
name=t(_p(
'cmd:economy_balance|embed:single|author',
"Balance statement for {user}"
)).format(user=str(user))
)
await ctx.reply(
embed=embed
)
# TODO: Add small transaction history block when we have transaction formatter
@economy_group.command(
name=_p('cmd:economy_reset', "reset"),
description=_p(
'cmd:economy_reset|desc',
"Reset the coin balance for a target user or role. (See also \"economy balance\".)"
)
)
@appcmds.rename(
target=_p('cmd:economy_reset|param:target', "target"),
)
@appcmds.describe(
target=_p(
'cmd:economy_reset|param:target|desc',
"Target user or role to view or update. Use @everyone to reset the entire guild."
),
)
async def economy_reset_cmd(
self,
ctx: LionContext,
target: discord.User | discord.Member | discord.Role,
):
# TODO: Permission check
t = self.bot.translator.t
starting_balance = 0
coin_emoji = self.bot.config.emojis.getemoji('coin')
# Typechecker guards
if not ctx.guild:
return
if not ctx.bot.core:
return
if not ctx.interaction:
return
if isinstance(target, discord.Role):
if target.is_default():
# Confirm: Reset Guild
confirm_msg = t(_p(
'cmd:economy_reset|confirm:reset_guild|desc',
"Are you sure you want to reset the coin balance for everyone in **{guild_name}**?\n"
"*This is not reversible!*"
)).format(
guild_name=ctx.guild.name
)
confirm = Confirm(confirm_msg)
confirm.confirm_button.label = t(_p(
'cmd:economy_reset|confirm:reset_guild|button:confirm',
"Yes, reset the economy"
))
confirm.cancel_button.label = t(_p(
'cmd:economy_reset|confirm:reset_guild|button:cancel',
"Cancel reset"
))
try:
result = await confirm.ask(ctx.interaction, ephemeral=True)
except ResponseTimedOut:
return
if result:
# Complete reset
await ctx.bot.core.data.Member.table.update_where(
guildid=ctx.guild.id,
).set(coins=starting_balance)
await ctx.reply(
embed=discord.Embed(
description=t(_p(
'cmd:economy_reset|embed:success_guild|desc',
"Everyone in **{guild_name}** has had their balance reset to {coin_emoji}**{amount}**."
)).format(
guild_name=ctx.guild.name,
coin_emoji=coin_emoji,
amount=starting_balance
)
)
)
else:
# Provided a role to reset
targets = [member for member in target.members if not member.bot]
if not targets:
# Error: No targets
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:economy_reset|error:no_target|desc',
"The role {mention} has no members to reset!"
)).format(mention=target.mention)
),
ephemeral=True
)
else:
# Confirm: Reset Role
# Include number of people affected
confirm_msg = t(_p(
'cmd:economy_reset|confirm:reset_role|desc',
"Are you sure you want to reset the balance for everyone in {mention}?\n"
"**{count}** members will be affected."
)).format(
mention=target.mention,
count=len(targets)
)
confirm = Confirm(confirm_msg)
confirm.confirm_button.label = t(_p(
'cmd:economy_reset|confirm:reset_role|button:confirm',
"Yes, complete economy reset"
))
confirm.cancel_button.label = t(_p(
'cmd:economy_reset|confirm:reset_role|button:cancel',
"Cancel"
))
try:
result = await confirm.ask(ctx.interaction, ephemeral=True)
except ResponseTimedOut:
return
if result:
# Complete reset
await ctx.bot.core.data.Member.table.update_where(
guildid=ctx.guild.id,
userid=[t.id for t in targets],
).set(coins=starting_balance)
await ctx.reply(
embed=discord.Embed(
description=t(_p(
'cmd:economy_reset|embed:success_role|desc',
"Everyone in {role_mention} has had their "
"coin balance reset to {coin_emoji}**{amount}**."
)).format(
mention=target.mention,
coin_emoji=coin_emoji,
amount=starting_balance
)
)
)
else:
# Provided an individual user.
# Reset their balance
# Do not create the member row if it does not already exist.
# TODO: Audit logging trail
await ctx.bot.core.data.Member.table.update_where(
guuildid=ctx.guild.id,
userid=target.id,
).set(coins=starting_balance)
await ctx.reply(
embed=discord.Embed(
description=t(_p(
'cmd:economy_reset|embed:success_user|desc',
"{mention}'s balance has been reset to {coin_emoji}**{amount}**."
)).format(
mention=target.mention,
coin_emoji=coin_emoji,
amount=starting_balance
)
)
)
@cmds.hybrid_command(
name=_p('cmd:send', "send"),
description=_p(
'cmd:send|desc',
"Gift the target user a certain number of LionCoins."
)
)
@appcmds.rename(
target=_p('cmd:send|param:target', "target"),
amount=_p('cmd:send|param:amount', "amount"),
note=_p('cmd:send|param:note', "note")
)
@appcmds.describe(
target=_p('cmd:send|param:target|desc', "User to send the gift to"),
amount=_p('cmd:send|param:amount|desc', "Number of coins to send"),
note=_p('cmd:send|param:note|desc', "Optional note to add to the gift.")
)
@appcmds.guild_only()
async def send_cmd(self, ctx: LionContext,
target: discord.User | discord.Member,
amount: appcmds.Range[int, 1, MAX_COINS],
note: Optional[str] = None):
"""
Send `amount` lioncoins to the provided `target`, with the optional `note` attached.
"""
if not ctx.interaction:
return
if not ctx.guild:
return
if not self.bot.core:
return
t = self.bot.translator.t
Member = self.bot.core.data.Member
target_lion = await self.bot.core.lions.fetch(ctx.guild.id, target.id)
# TODO: Add a "Send thanks" button to the DM?
# Alternative flow could be waiting until the target user presses accept
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
conn = await self.bot.db.get_connection()
async with conn.transaction():
# We do this in a transaction so that if something goes wrong,
# the coins deduction is rolled back atomicly
balance = ctx.alion.data.coins
if amount > balance:
await ctx.interaction.edit_original_response(
embed=error_embed(
t(_p(
'cmd:send|error:insufficient',
"You do not have enough lioncoins to do this!\n"
"`Current Balance:` {coin_emoji}{balance}"
)).format(
coin_emoji=self.bot.config.emojis.getemoji('coin'),
balance=balance
)
),
)
return
# Transfer the coins
await ctx.alion.data.update(coins=(Member.coins - amount))
await target_lion.data.update(coins=(Member.coins + amount))
# TODO: Audit trail
# Message target
embed = discord.Embed(
title=t(_p(
'cmd:send|embed:gift|title',
"{user} sent you a gift!"
)).format(user=ctx.author.name),
description=t(_p(
'cmd:send|embed:gift|desc',
"{mention} sent you {coin_emoji}**{amount}**."
)).format(
coin_emoji=self.bot.config.emojis.getemoji('coin'),
amount=amount,
mention=ctx.author.mention
),
timestamp=utc_now()
)
if note:
embed.add_field(
name="Note Attached",
value=note
)
try:
await target.send(embed=embed)
failed = False
except discord.HTTPException:
failed = True
pass
# Ack transfer
embed = discord.Embed(
colour=discord.Colour.brand_green(),
description=t(_p(
'cmd:send|embed:ack|desc',
"**{coin_emoji}{amount}** has been deducted from your balance and sent to {mention}!"
)).format(
coin_emoji=self.bot.config.emojis.getemoji('coin'),
amount=amount,
mention=target.mention
)
)
if failed:
embed.description = t(_p(
'cmd:send|embed:ack|desc|error:unreachable',
"Unfortunately, I was not able to message the recipient. Perhaps they have me blocked?"
))
await ctx.interaction.edit_original_response(embed=embed)

View File

@@ -0,0 +1,16 @@
from .sysadmin import *
from .guild_admin import *
from .meta import *
from .economy import *
from .study import *
from .stats import *
from .user_config import *
from .workout import *
from .todo import *
from .topgg import *
from .reminders import *
from .renting import *
from .moderation import *
from .accountability import *
from .plugins import *
from .sponsors import *

View File

@@ -0,0 +1,477 @@
from typing import List, Dict
import datetime
import discord
import asyncio
import random
from settings import GuildSettings
from utils.lib import tick, cross
from core import Lion
from meta import client
from .lib import utc_now
from .data import accountability_members, accountability_rooms
class SlotMember:
"""
Class representing a member booked into an accountability room.
Mostly acts as an interface to the corresponding TableRow.
But also stores the discord.Member associated, and has several computed properties.
The member may be None.
"""
___slots__ = ('slotid', 'userid', 'guild')
def __init__(self, slotid, userid, guild):
self.slotid = slotid
self.userid = userid
self.guild = guild
self._member = None
@property
def key(self):
return (self.slotid, self.userid)
@property
def data(self):
return accountability_members.fetch(self.key)
@property
def member(self):
return self.guild.get_member(self.userid)
@property
def has_attended(self):
return self.data.duration > 0 or self.data.last_joined_at
class TimeSlot:
"""
Class representing an accountability slot.
"""
__slots__ = (
'guild',
'start_time',
'data',
'lobby',
'category',
'channel',
'message',
'members'
)
slots = {}
_member_overwrite = discord.PermissionOverwrite(
view_channel=True,
connect=True
)
_everyone_overwrite = discord.PermissionOverwrite(
view_channel=False,
connect=False,
speak=False
)
happy_lion = "https://media.discordapp.net/stickers/898266283559227422.png"
sad_lion = "https://media.discordapp.net/stickers/898266548421148723.png"
def __init__(self, guild, start_time, data=None):
self.guild: discord.Guild = guild
self.start_time: datetime.datetime = start_time
self.data = data
self.lobby: discord.TextChannel = None # Text channel to post the slot status
self.category: discord.CategoryChannel = None # Category to create the voice rooms in
self.channel: discord.VoiceChannel = None # Text channel associated with this time slot
self.message: discord.Message = None # Status message in lobby channel
self.members: Dict[int, SlotMember] = {} # memberid -> SlotMember
@property
def open_embed(self):
timestamp = int(self.start_time.timestamp())
embed = discord.Embed(
title="Session <t:{}:t> - <t:{}:t>".format(
timestamp, timestamp + 3600
),
colour=discord.Colour.orange(),
timestamp=self.start_time
).set_footer(
text="About to start!\nJoin the session with {}schedule book".format(client.prefix)
)
if self.members:
embed.description = "Starting <t:{}:R>.".format(timestamp)
embed.add_field(
name="Members",
value=(
', '.join('<@{}>'.format(key) for key in self.members.keys())
)
)
else:
embed.description = "No members scheduled this session!"
return embed
@property
def status_embed(self):
timestamp = int(self.start_time.timestamp())
embed = discord.Embed(
title="Session <t:{}:t> - <t:{}:t>".format(
timestamp, timestamp + 3600
),
description="Finishing <t:{}:R>.".format(timestamp + 3600),
colour=discord.Colour.orange(),
timestamp=self.start_time
).set_footer(text="Join the next session using {}schedule book".format(client.prefix))
if self.members:
classifications = {
"Attended": [],
"Studying Now": [],
"Waiting for": []
}
for memid, mem in self.members.items():
mention = '<@{}>'.format(memid)
if not mem.has_attended:
classifications["Waiting for"].append(mention)
elif mem.member in self.channel.members:
classifications["Studying Now"].append(mention)
else:
classifications["Attended"].append(mention)
all_attended = all(mem.has_attended for mem in self.members.values())
bonus_line = (
"{tick} Everyone attended, and will get a `{bonus} LC` bonus!".format(
tick=tick,
bonus=GuildSettings(self.guild.id).accountability_bonus.value
)
if all_attended else ""
)
if all_attended:
embed.set_thumbnail(url=self.happy_lion)
embed.description += "\n" + bonus_line
for field, value in classifications.items():
if value:
embed.add_field(name=field, value='\n'.join(value))
else:
embed.description = "No members scheduled this session!"
return embed
@property
def summary_embed(self):
timestamp = int(self.start_time.timestamp())
embed = discord.Embed(
title="Session <t:{}:t> - <t:{}:t>".format(
timestamp, timestamp + 3600
),
description="Finished <t:{}:R>.".format(timestamp + 3600),
colour=discord.Colour.orange(),
timestamp=self.start_time
).set_footer(text="Completed!")
if self.members:
classifications = {
"Attended": [],
"Missing": []
}
for memid, mem in sorted(self.members.items(), key=lambda mem: mem[1].data.duration, reverse=True):
mention = '<@{}>'.format(memid)
if mem.has_attended:
classifications["Attended"].append(
"{} ({}%)".format(mention, (mem.data.duration * 100) // 3600)
)
else:
classifications["Missing"].append(mention)
all_attended = all(mem.has_attended for mem in self.members.values())
bonus_line = (
"{tick} Everyone attended, and received a `{bonus} LC` bonus!".format(
tick=tick,
bonus=GuildSettings(self.guild.id).accountability_bonus.value
)
if all_attended else
"{cross} Some members missed the session, so everyone missed out on the bonus!".format(
cross=cross
)
)
if all_attended:
embed.set_thumbnail(url=self.happy_lion)
else:
embed.set_thumbnail(url=self.sad_lion)
embed.description += "\n" + bonus_line
for field, value in classifications.items():
if value:
embed.add_field(name=field, value='\n'.join(value))
else:
embed.description = "No members scheduled this session!"
return embed
def load(self, memberids: List[int] = None):
"""
Load data and update applicable caches.
"""
if not self.guild:
return self
# Load setting data
self.category = GuildSettings(self.guild.id).accountability_category.value
self.lobby = GuildSettings(self.guild.id).accountability_lobby.value
if self.data:
# Load channel
if self.data.channelid:
self.channel = self.guild.get_channel(self.data.channelid)
# Load message
if self.data.messageid and self.lobby:
self.message = discord.PartialMessage(
channel=self.lobby,
id=self.data.messageid
)
# Load members
if memberids:
self.members = {
memberid: SlotMember(self.data.slotid, memberid, self.guild)
for memberid in memberids
}
return self
async def _reload_members(self, memberids=None):
"""
Reload the timeslot members from the provided list, or data.
Also updates the channel overwrites if required.
To be used before the session has started.
"""
if self.data:
if memberids is None:
member_rows = accountability_members.fetch_rows_where(slotid=self.data.slotid)
memberids = [row.userid for row in member_rows]
self.members = members = {
memberid: SlotMember(self.data.slotid, memberid, self.guild)
for memberid in memberids
}
if self.channel:
# Check and potentially update overwrites
current_overwrites = self.channel.overwrites
overwrites = {
mem.member: self._member_overwrite
for mem in members.values()
if mem.member
}
overwrites[self.guild.default_role] = self._everyone_overwrite
if current_overwrites != overwrites:
await self.channel.edit(overwrites=overwrites)
def _refresh(self):
"""
Refresh the stored data row and reload.
"""
rows = accountability_rooms.fetch_rows_where(
guildid=self.guild.id,
start_at=self.start_time
)
self.data = rows[0] if rows else None
memberids = []
if self.data:
member_rows = accountability_members.fetch_rows_where(
slotid=self.data.slotid
)
memberids = [row.userid for row in member_rows]
self.load(memberids=memberids)
async def open(self):
"""
Open the accountability room.
Creates a new voice channel, and sends the status message.
Event logs any issues.
Adds the TimeSlot to cache.
Returns the (channelid, messageid).
"""
# Cleanup any non-existent members
for memid, mem in list(self.members.items()):
if not mem.data or not mem.member:
self.members.pop(memid)
# Calculate overwrites
overwrites = {
mem.member: self._member_overwrite
for mem in self.members.values()
}
overwrites[self.guild.default_role] = self._everyone_overwrite
# Create the channel. Log and bail if something went wrong.
if self.data and not self.channel:
try:
self.channel = await self.guild.create_voice_channel(
"Upcoming Scheduled Session",
overwrites=overwrites,
category=self.category
)
except discord.HTTPException:
GuildSettings(self.guild.id).event_log.log(
"Failed to create the scheduled session voice channel. Skipping this session.",
colour=discord.Colour.red()
)
return None
elif not self.data:
self.channel = None
# Send the inital status message. Log and bail if something goes wrong.
if not self.message:
try:
self.message = await self.lobby.send(
embed=self.open_embed
)
except discord.HTTPException:
GuildSettings(self.guild.id).event_log.log(
"Failed to post the status message in the scheduled session lobby {}.\n"
"Skipping this session.".format(self.lobby.mention),
colour=discord.Colour.red()
)
return None
if self.members:
await self.channel_notify()
return (self.channel.id if self.channel else None, self.message.id)
async def channel_notify(self, content=None):
"""
Ghost pings the session members in the lobby channel.
"""
if self.members:
content = content or "Your scheduled session has started! Please join!"
out = "{}\n\n{}".format(
content,
' '.join('<@{}>'.format(memid) for memid, mem in self.members.items() if not mem.has_attended)
)
out_msg = await self.lobby.send(out)
await out_msg.delete()
async def start(self):
"""
Start the accountability room slot.
Update the status message, and launch the DM reminder.
"""
dither = 15 * random.random()
await asyncio.sleep(dither)
if self.channel:
try:
await self.channel.edit(name="Scheduled Session Room")
await self.channel.set_permissions(self.guild.default_role, view_channel=True, connect=False)
except discord.HTTPException:
pass
asyncio.create_task(self.dm_reminder(delay=60))
try:
await self.message.edit(embed=self.status_embed)
except discord.NotFound:
try:
self.message = await self.lobby.send(
embed=self.status_embed
)
except discord.HTTPException:
self.message = None
async def dm_reminder(self, delay=60):
"""
Notifies missing members with a direct message after 1 minute.
"""
await asyncio.sleep(delay)
embed = discord.Embed(
title="The scheduled session you booked has started!",
description="Please join {}.".format(self.channel.mention),
colour=discord.Colour.orange()
).set_footer(
text=self.guild.name,
icon_url=self.guild.icon_url
)
members = (mem.member for mem in self.members.values() if not mem.has_attended)
members = (member for member in members if member)
await asyncio.gather(
*(member.send(embed=embed) for member in members),
return_exceptions=True
)
async def close(self):
"""
Delete the channel and update the status message to display a session summary.
Unloads the TimeSlot from cache.
"""
dither = 15 * random.random()
await asyncio.sleep(dither)
if self.channel:
try:
await self.channel.delete()
except discord.HTTPException:
pass
if self.message:
try:
await self.message.edit(embed=self.summary_embed)
except discord.HTTPException:
pass
# Reward members appropriately
if self.guild:
guild_settings = GuildSettings(self.guild.id)
reward = guild_settings.accountability_reward.value
if all(mem.has_attended for mem in self.members.values()):
reward += guild_settings.accountability_bonus.value
for memid in self.members:
Lion.fetch(self.guild.id, memid).addCoins(reward, bonus=True)
async def cancel(self):
"""
Cancel the slot, generally due to missing data.
Updates the message and channel if possible, removes slot from cache, and also updates data.
# TODO: Refund members
"""
if self.data:
self.data.closed_at = utc_now()
if self.channel:
try:
await self.channel.delete()
self.channel = None
except discord.HTTPException:
pass
if self.message:
try:
timestamp = int(self.start_time.timestamp())
embed = discord.Embed(
title="Session <t:{}:t> - <t:{}:t>".format(
timestamp, timestamp + 3600
),
description="Session canceled!",
colour=discord.Colour.red()
)
await self.message.edit(embed=embed)
except discord.HTTPException:
pass
async def update_status(self):
"""
Intelligently update the status message.
"""
if self.message:
if utc_now() < self.start_time:
await self.message.edit(embed=self.open_embed)
elif utc_now() < self.start_time + datetime.timedelta(hours=1):
await self.message.edit(embed=self.status_embed)
else:
await self.message.edit(embed=self.summary_embed)

View File

@@ -0,0 +1,6 @@
from .module import module
from . import data
from . import admin
from . import commands
from . import tracker

View File

@@ -0,0 +1,140 @@
import asyncio
import discord
import settings
from settings import GuildSettings, GuildSetting
from .tracker import AccountabilityGuild as AG
@GuildSettings.attach_setting
class accountability_category(settings.Channel, settings.GuildSetting):
category = "Scheduled Sessions"
attr_name = "accountability_category"
_data_column = "accountability_category"
display_name = "session_category"
desc = "Category in which to make the scheduled session rooms."
_default = None
long_desc = (
"\"Schedule session\" category channel.\n"
"Scheduled sessions will be held in voice channels created under this category."
)
_accepts = "A category channel."
_chan_type = discord.ChannelType.category
@property
def success_response(self):
if self.value:
# TODO Move this somewhere better
if self.id not in AG.cache:
AG(self.id)
return "The session category has been changed to **{}**.".format(self.value.name)
else:
return "The scheduled session system has been started in **{}**.".format(self.value.name)
else:
if self.id in AG.cache:
aguild = AG.cache.pop(self.id)
if aguild.current_slot:
asyncio.create_task(aguild.current_slot.cancel())
if aguild.upcoming_slot:
asyncio.create_task(aguild.upcoming_slot.cancel())
return "The scheduled session system has been shut down."
else:
return "The scheduled session category has been unset."
@GuildSettings.attach_setting
class accountability_lobby(settings.Channel, settings.GuildSetting):
category = "Scheduled Sessions"
attr_name = "accountability_lobby"
_data_column = attr_name
display_name = "session_lobby"
desc = "Category in which to post scheduled session notifications updates."
_default = None
long_desc = (
"Scheduled session updates will be posted here, and members will be notified in this channel.\n"
"The channel will be automatically created in the configured `session_category` if it does not exist.\n"
"Members do not need to be able to write in the channel."
)
_accepts = "Any text channel."
_chan_type = discord.ChannelType.text
async def auto_create(self):
# TODO: FUTURE
...
@GuildSettings.attach_setting
class accountability_price(settings.Integer, GuildSetting):
category = "Scheduled Sessions"
attr_name = "accountability_price"
_data_column = attr_name
display_name = "session_price"
desc = "Cost of booking a scheduled session."
_default = 500
long_desc = (
"The price of booking each one hour scheduled session slot."
)
_accepts = "An integer number of coins."
@property
def success_response(self):
return "Scheduled session slots now cost `{}` coins.".format(self.value)
@GuildSettings.attach_setting
class accountability_bonus(settings.Integer, GuildSetting):
category = "Scheduled Sessions"
attr_name = "accountability_bonus"
_data_column = attr_name
display_name = "session_bonus"
desc = "Bonus given when everyone attends a scheduled session slot."
_default = 750
long_desc = (
"The extra bonus given to each scheduled session member when everyone who booked attended the session."
)
_accepts = "An integer number of coins."
@property
def success_response(self):
return "Scheduled session members will now get `{}` coins if everyone joins.".format(self.value)
@GuildSettings.attach_setting
class accountability_reward(settings.Integer, GuildSetting):
category = "Scheduled Sessions"
attr_name = "accountability_reward"
_data_column = attr_name
display_name = "session_reward"
desc = "The individual reward given when a member attends their booked scheduled session."
_default = 500
long_desc = (
"Reward given to a member who attends a booked scheduled session."
)
_accepts = "An integer number of coins."
@property
def success_response(self):
return "Members will now get `{}` coins when they attend their scheduled session.".format(self.value)

View File

@@ -0,0 +1,596 @@
import re
import datetime
import discord
import asyncio
import contextlib
from cmdClient.checks import in_guild
from meta import client
from utils.lib import multiselect_regex, parse_ranges, prop_tabulate
from data import NOTNULL
from data.conditions import GEQ, LEQ
from .module import module
from .lib import utc_now
from .tracker import AccountabilityGuild as AGuild
from .tracker import room_lock
from .TimeSlot import SlotMember
from .data import accountability_members, accountability_member_info, accountability_rooms
hint_icon = "https://projects.iamcal.com/emoji-data/img-apple-64/1f4a1.png"
def time_format(time):
diff = (time - utc_now()).total_seconds()
if diff < 0:
diffstr = "`Right Now!!`"
elif diff < 600:
diffstr = "`Very soon!!`"
elif diff < 3600:
diffstr = "`In <1 hour `"
else:
hours = round(diff / 3600)
diffstr = "`In {:>2} hour{}`".format(hours, 's' if hours > 1 else ' ')
return "{} | <t:{:.0f}:t> - <t:{:.0f}:t>".format(
diffstr,
time.timestamp(),
time.timestamp() + 3600,
)
user_locks = {} # Map userid -> ctx
@contextlib.contextmanager
def ensure_exclusive(ctx):
"""
Cancel any existing exclusive contexts for the author.
"""
old_ctx = user_locks.pop(ctx.author.id, None)
if old_ctx:
[task.cancel() for task in old_ctx.tasks]
user_locks[ctx.author.id] = ctx
try:
yield
finally:
new_ctx = user_locks.get(ctx.author.id, None)
if new_ctx and new_ctx.msg.id == ctx.msg.id:
user_locks.pop(ctx.author.id)
@module.cmd(
name="schedule",
desc="View your schedule, and get rewarded for attending scheduled sessions!",
group="Productivity",
aliases=('rooms', 'sessions')
)
@in_guild()
async def cmd_rooms(ctx):
"""
Usage``:
{prefix}schedule
{prefix}schedule book
{prefix}schedule cancel
Description:
View your schedule with `{prefix}schedule`.
Use `{prefix}schedule book` to schedule a session at a selected time..
Use `{prefix}schedule cancel` to cancel a scheduled session.
"""
lower = ctx.args.lower()
splits = lower.split()
command = splits[0] if splits else None
if not ctx.guild_settings.accountability_category.value:
return await ctx.error_reply("The scheduled session system isn't set up!")
# First grab the sessions the member is booked in
joined_rows = accountability_member_info.select_where(
userid=ctx.author.id,
start_at=GEQ(utc_now()),
_extra="ORDER BY start_at ASC"
)
if command == 'cancel':
if not joined_rows:
return await ctx.error_reply("You have no scheduled sessions to cancel!")
# Show unbooking menu
lines = [
"`[{:>2}]` | {}".format(i, time_format(row['start_at']))
for i, row in enumerate(joined_rows)
]
out_msg = await ctx.reply(
content="Please reply with the number(s) of the sessions you want to cancel. E.g. `1, 3, 5` or `1-3, 7-8`.",
embed=discord.Embed(
title="Please choose the sessions you want to cancel.",
description='\n'.join(lines),
colour=discord.Colour.orange()
).set_footer(
text=(
"All times are in your own timezone! Hover over a time to see the date."
)
)
)
await ctx.cancellable(
out_msg,
cancel_message="Cancel menu closed, no scheduled sessions were cancelled.",
timeout=70
)
def check(msg):
valid = msg.channel == ctx.ch and msg.author == ctx.author
valid = valid and (re.search(multiselect_regex, msg.content) or msg.content.lower() == 'c')
return valid
with ensure_exclusive(ctx):
try:
message = await ctx.client.wait_for('message', check=check, timeout=60)
except asyncio.TimeoutError:
try:
await out_msg.edit(
content=None,
embed=discord.Embed(
description="Cancel menu timed out, no scheduled sessions were cancelled.",
colour=discord.Colour.red()
)
)
await out_msg.clear_reactions()
except discord.HTTPException:
pass
return
try:
await out_msg.delete()
await message.delete()
except discord.HTTPException:
pass
if message.content.lower() == 'c':
return
to_cancel = [
joined_rows[index]
for index in parse_ranges(message.content) if index < len(joined_rows)
]
if not to_cancel:
return await ctx.error_reply("No valid sessions selected for cancellation.")
elif any(row['start_at'] < utc_now() for row in to_cancel):
return await ctx.error_reply("You can't cancel a running session!")
slotids = [row['slotid'] for row in to_cancel]
async with room_lock:
deleted = accountability_members.delete_where(
userid=ctx.author.id,
slotid=slotids
)
# Handle case where the slot has already opened
# TODO: Possible race condition if they open over the hour border? Might never cancel
for row in to_cancel:
aguild = AGuild.cache.get(row['guildid'], None)
if aguild and aguild.upcoming_slot and aguild.upcoming_slot.data:
if aguild.upcoming_slot.data.slotid in slotids:
aguild.upcoming_slot.members.pop(ctx.author.id, None)
if aguild.upcoming_slot.channel:
try:
await aguild.upcoming_slot.channel.set_permissions(
ctx.author,
overwrite=None
)
except discord.HTTPException:
pass
await aguild.upcoming_slot.update_status()
break
ctx.alion.addCoins(sum(row[2] for row in deleted))
remaining = [row for row in joined_rows if row['slotid'] not in slotids]
if not remaining:
await ctx.embed_reply("Cancelled all your upcoming scheduled sessions!")
else:
next_booked_time = min(row['start_at'] for row in remaining)
if len(to_cancel) > 1:
await ctx.embed_reply(
"Cancelled `{}` upcoming sessions!\nYour next session is at <t:{:.0f}>.".format(
len(to_cancel),
next_booked_time.timestamp()
)
)
else:
await ctx.embed_reply(
"Cancelled your session at <t:{:.0f}>!\n"
"Your next session is at <t:{:.0f}>.".format(
to_cancel[0]['start_at'].timestamp(),
next_booked_time.timestamp()
)
)
elif command == 'book':
# Show booking menu
# Get attendee count
rows = accountability_member_info.select_where(
guildid=ctx.guild.id,
userid=NOTNULL,
select_columns=(
'slotid',
'start_at',
'COUNT(*) as num'
),
_extra="GROUP BY start_at, slotid"
)
attendees = {row['start_at']: row['num'] for row in rows}
attendee_pad = max((len(str(num)) for num in attendees.values()), default=1)
# Build lines
already_joined_times = set(row['start_at'] for row in joined_rows)
start_time = utc_now().replace(minute=0, second=0, microsecond=0)
times = (
start_time + datetime.timedelta(hours=n)
for n in range(1, 25)
)
times = [
time for time in times
if time not in already_joined_times and (time - utc_now()).total_seconds() > 660
]
lines = [
"`[{num:>2}]` | `{count:>{count_pad}}` attending | {time}".format(
num=i,
count=attendees.get(time, 0), count_pad=attendee_pad,
time=time_format(time),
)
for i, time in enumerate(times)
]
# TODO: Nicer embed
# TODO: Don't allow multi bookings if the member has a bad attendance rate
out_msg = await ctx.reply(
content=(
"Please reply with the number(s) of the sessions you want to book. E.g. `1, 3, 5` or `1-3, 7-8`."
),
embed=discord.Embed(
title="Please choose the sessions you want to schedule.",
description='\n'.join(lines),
colour=discord.Colour.orange()
).set_footer(
text=(
"All times are in your own timezone! Hover over a time to see the date."
)
)
)
await ctx.cancellable(
out_msg,
cancel_message="Booking menu cancelled, no sessions were booked.",
timeout=60
)
def check(msg):
valid = msg.channel == ctx.ch and msg.author == ctx.author
valid = valid and (re.search(multiselect_regex, msg.content) or msg.content.lower() == 'c')
return valid
with ensure_exclusive(ctx):
try:
message = await ctx.client.wait_for('message', check=check, timeout=30)
except asyncio.TimeoutError:
try:
await out_msg.edit(
content=None,
embed=discord.Embed(
description="Booking menu timed out, no sessions were booked.",
colour=discord.Colour.red()
)
)
await out_msg.clear_reactions()
except discord.HTTPException:
pass
return
try:
await out_msg.delete()
await message.delete()
except discord.HTTPException:
pass
if message.content.lower() == 'c':
return
to_book = [
times[index]
for index in parse_ranges(message.content) if index < len(times)
]
if not to_book:
return await ctx.error_reply("No valid sessions selected.")
elif any(time < utc_now() for time in to_book):
return await ctx.error_reply("You can't book a running session!")
cost = len(to_book) * ctx.guild_settings.accountability_price.value
if cost > ctx.alion.coins:
return await ctx.error_reply(
"Sorry, booking `{}` sessions costs `{}` coins, and you only have `{}`!".format(
len(to_book),
cost,
ctx.alion.coins
)
)
# Add the member to data, creating the row if required
slot_rows = accountability_rooms.fetch_rows_where(
guildid=ctx.guild.id,
start_at=to_book
)
slotids = [row.slotid for row in slot_rows]
to_add = set(to_book).difference((row.start_at for row in slot_rows))
if to_add:
slotids.extend(row['slotid'] for row in accountability_rooms.insert_many(
*((ctx.guild.id, start_at) for start_at in to_add),
insert_keys=('guildid', 'start_at'),
))
accountability_members.insert_many(
*((slotid, ctx.author.id, ctx.guild_settings.accountability_price.value) for slotid in slotids),
insert_keys=('slotid', 'userid', 'paid')
)
# Handle case where the slot has already opened
# TODO: Fix this, doesn't always work
aguild = AGuild.cache.get(ctx.guild.id, None)
if aguild:
if aguild.upcoming_slot and aguild.upcoming_slot.start_time in to_book:
slot = aguild.upcoming_slot
if not slot.data:
# Handle slot activation
slot._refresh()
channelid, messageid = await slot.open()
accountability_rooms.update_where(
{'channelid': channelid, 'messageid': messageid},
slotid=slot.data.slotid
)
else:
slot.members[ctx.author.id] = SlotMember(slot.data.slotid, ctx.author.id, ctx.guild)
# Also update the channel permissions
try:
await slot.channel.set_permissions(ctx.author, view_channel=True, connect=True)
except discord.HTTPException:
pass
await slot.update_status()
ctx.alion.addCoins(-cost)
# Ack purchase
embed = discord.Embed(
title="You have scheduled the following session{}!".format('s' if len(to_book) > 1 else ''),
description=(
"*Please attend all your scheduled sessions!*\n"
"*If you can't attend, cancel with* `{}schedule cancel`\n\n{}"
).format(
ctx.best_prefix,
'\n'.join(time_format(time) for time in to_book),
),
colour=discord.Colour.orange()
).set_footer(
text=(
"Use {prefix}schedule to see your current schedule.\n"
).format(prefix=ctx.best_prefix)
)
try:
await ctx.reply(
embed=embed,
reference=ctx.msg
)
except discord.NotFound:
await ctx.reply(embed=embed)
else:
# Show accountability room information for this user
# Accountability profile
# Author
# Special case for no past bookings, emphasis hint
# Hint on Bookings section for booking/cancelling as applicable
# Description has stats
# Footer says that all times are in their timezone
# TODO: attendance requirement shouldn't be retroactive! Add attended data column
# Attended `{}` out of `{}` booked (`{}%` attendance rate!)
# Attendance streak: `{}` days attended with no missed sessions!
# Add explanation for first time users
# Get all slots the member has ever booked
history = accountability_member_info.select_where(
userid=ctx.author.id,
# start_at=LEQ(utc_now() - datetime.timedelta(hours=1)),
start_at=LEQ(utc_now()),
select_columns=("*", "(duration > 0 OR last_joined_at IS NOT NULL) AS attended"),
_extra="ORDER BY start_at DESC"
)
if not (history or joined_rows):
# First-timer information
about = (
"You haven't scheduled any study sessions yet!\n"
"Schedule a session by typing **`{}schedule book`** and selecting "
"the hours you intend to study, "
"then attend by joining the session voice channel when it starts!\n"
"Only if everyone attends will they get the bonus of `{}` LionCoins!\n"
"Let's all do our best and keep each other accountable 🔥"
).format(
ctx.best_prefix,
ctx.guild_settings.accountability_bonus.value
)
embed = discord.Embed(
description=about,
colour=discord.Colour.orange()
)
embed.set_footer(
text="Please keep your DMs open so I can notify you when the session starts!\n",
icon_url=hint_icon
)
await ctx.reply(embed=embed)
else:
# Build description with stats
if history:
# First get the counts
attended_count = sum(row['attended'] for row in history)
total_count = len(history)
total_duration = sum(row['duration'] for row in history)
# Add current session to duration if it exists
if history[0]['last_joined_at'] and (utc_now() - history[0]['start_at']).total_seconds() < 3600:
total_duration += int((utc_now() - history[0]['last_joined_at']).total_seconds())
# Calculate the streak
timezone = ctx.alion.settings.timezone.value
streak = 0
current_streak = None
max_streak = 0
day_attended = None
date = utc_now().astimezone(timezone).replace(hour=0, minute=0, second=0, microsecond=0)
daydiff = datetime.timedelta(days=1)
i = 0
while i < len(history):
row = history[i]
i += 1
if not row['attended']:
# Not attended, streak broken
pass
elif row['start_at'] > date:
# They attended this day
day_attended = True
continue
elif day_attended is None:
# Didn't attend today, but don't break streak
day_attended = False
date -= daydiff
i -= 1
continue
elif not day_attended:
# Didn't attend the day, streak broken
date -= daydiff
i -= 1
pass
else:
# Attended the day
streak += 1
# Move window to the previous day and try the row again
date -= daydiff
day_attended = False
i -= 1
continue
max_streak = max(max_streak, streak)
if current_streak is None:
current_streak = streak
streak = 0
# Handle loop exit state, i.e. the last streak
if day_attended:
streak += 1
max_streak = max(max_streak, streak)
if current_streak is None:
current_streak = streak
# Build the stats
table = {
"Sessions": "**{}** attended out of **{}**, `{:.0f}%` attendance rate.".format(
attended_count,
total_count,
(attended_count * 100) / total_count,
),
"Time": "**{:02}:{:02}** in scheduled sessions.".format(
total_duration // 3600,
(total_duration % 3600) // 60
),
"Streak": "**{}** day{} with no missed sessions! (Longest: **{}** day{}.)".format(
current_streak,
's' if current_streak != 1 else '',
max_streak,
's' if max_streak != 1 else '',
),
}
desc = prop_tabulate(*zip(*table.items()))
else:
desc = (
"Good luck with your next session!\n"
)
# Build currently booked list
if joined_rows:
# TODO: (Future) calendar link
# Get attendee counts for currently booked sessions
rows = accountability_member_info.select_where(
slotid=[row["slotid"] for row in joined_rows],
userid=NOTNULL,
select_columns=(
'slotid',
'guildid',
'start_at',
'COUNT(*) as num'
),
_extra="GROUP BY start_at, slotid, guildid ORDER BY start_at ASC"
)
attendees = {
row['start_at']: (row['num'], row['guildid']) for row in rows
}
attendee_pad = max((len(str(num)) for num, _ in attendees.values()), default=1)
# TODO: Allow cancel to accept multiselect keys as args
show_guild = any(guildid != ctx.guild.id for _, guildid in attendees.values())
guild_map = {}
if show_guild:
for _, guildid in attendees.values():
if guildid not in guild_map:
guild = ctx.client.get_guild(guildid)
if not guild:
try:
guild = await ctx.client.fetch_guild(guildid)
except discord.HTTPException:
guild = None
guild_map[guildid] = guild
booked_list = '\n'.join(
"`{:>{}}` attendees | {} {}".format(
num,
attendee_pad,
time_format(start),
"" if not show_guild else (
"on this server" if guildid == ctx.guild.id else "in **{}**".format(
guild_map[guildid] or "Unknown"
)
)
) for start, (num, guildid) in attendees.items()
)
booked_field = (
"{}\n\n"
"*If you can't make your session, please cancel using `{}schedule cancel`!*"
).format(booked_list, ctx.best_prefix)
# Temporary footer for acclimatisation
# footer = "All times are displayed in your own timezone!"
footer = "Book another session using {}schedule book".format(ctx.best_prefix)
else:
booked_field = (
"Your schedule is empty!\n"
"Book another session using `{}schedule book`."
).format(ctx.best_prefix)
footer = "Please keep your DMs open for notifications!"
# Finally, build embed
embed = discord.Embed(
colour=discord.Colour.orange(),
description=desc,
).set_author(
name="Schedule statistics for {}".format(ctx.author.name),
icon_url=ctx.author.avatar_url
).set_footer(
text=footer,
icon_url=hint_icon
).add_field(
name="Upcoming sessions",
value=booked_field
)
# And send it!
await ctx.reply(embed=embed)
# TODO: roomadmin

View File

@@ -0,0 +1,34 @@
from data import Table, RowTable
from cachetools import TTLCache
accountability_rooms = RowTable(
'accountability_slots',
('slotid', 'channelid', 'guildid', 'start_at', 'messageid', 'closed_at'),
'slotid',
cache=TTLCache(5000, ttl=60*70),
attach_as='accountability_rooms'
)
accountability_members = RowTable(
'accountability_members',
('slotid', 'userid', 'paid', 'duration', 'last_joined_at'),
('slotid', 'userid'),
cache=TTLCache(5000, ttl=60*70)
)
accountability_member_info = Table('accountability_member_info')
accountability_open_slots = Table('accountability_open_slots')
# @accountability_member_info.save_query
# def user_streaks(userid, min_duration):
# with accountability_member_info.conn as conn:
# cursor = conn.cursor()
# with cursor:
# cursor.execute(
# """
# """
# )

View File

@@ -0,0 +1,8 @@
import datetime
def utc_now():
"""
Return the current timezone-aware utc timestamp.
"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)

View File

@@ -0,0 +1,4 @@
from LionModule import LionModule
module = LionModule("Accountability")

View File

@@ -0,0 +1,515 @@
import asyncio
import datetime
import collections
import traceback
import logging
import discord
from typing import Dict
from discord.utils import sleep_until
from meta import client
from utils.interactive import discord_shield
from data import NULL, NOTNULL, tables
from data.conditions import LEQ, THIS_SHARD
from settings import GuildSettings
from .TimeSlot import TimeSlot
from .lib import utc_now
from .data import accountability_rooms, accountability_members
from .module import module
voice_ignore_lock = asyncio.Lock()
room_lock = asyncio.Lock()
def locker(lock):
"""
Function decorator to wrap the function in a provided Lock
"""
def decorator(func):
async def wrapped(*args, **kwargs):
async with lock:
return await func(*args, **kwargs)
return wrapped
return decorator
class AccountabilityGuild:
__slots__ = ('guildid', 'current_slot', 'upcoming_slot')
cache: Dict[int, 'AccountabilityGuild'] = {} # Map guildid -> AccountabilityGuild
def __init__(self, guildid):
self.guildid = guildid
self.current_slot = None
self.upcoming_slot = None
self.cache[guildid] = self
@property
def guild(self):
return client.get_guild(self.guildid)
@property
def guild_settings(self):
return GuildSettings(self.guildid)
def advance(self):
self.current_slot = self.upcoming_slot
self.upcoming_slot = None
async def open_next(start_time):
"""
Open all the upcoming accountability rooms, and fire channel notify.
To be executed ~5 minutes to the hour.
"""
# Pre-fetch the new slot data, also populating the table caches
room_data = accountability_rooms.fetch_rows_where(
start_at=start_time,
guildid=THIS_SHARD
)
guild_rows = {row.guildid: row for row in room_data}
member_data = accountability_members.fetch_rows_where(
slotid=[row.slotid for row in room_data]
) if room_data else []
slot_memberids = collections.defaultdict(list)
for row in member_data:
slot_memberids[row.slotid].append(row.userid)
# Open a new slot in each accountability guild
to_update = [] # Cache of slot update data to be applied at the end
for aguild in list(AccountabilityGuild.cache.values()):
guild = aguild.guild
if guild:
# Initialise next TimeSlot
slot = TimeSlot(
guild,
start_time,
data=guild_rows.get(aguild.guildid, None)
)
slot.load(memberids=slot_memberids[slot.data.slotid] if slot.data else None)
if not slot.category:
# Log and unload guild
aguild.guild_settings.event_log.log(
"The scheduled session category couldn't be found!\n"
"Shutting down the scheduled session system in this server.\n"
"To re-activate, please reconfigure `config session_category`."
)
AccountabilityGuild.cache.pop(aguild.guildid, None)
await slot.cancel()
continue
elif not slot.lobby:
# TODO: Consider putting in TimeSlot.open().. or even better in accountability_lobby.create()
# Create a new lobby
try:
channel = await guild.create_text_channel(
name="session-lobby",
category=slot.category,
reason="Automatic creation of scheduled session lobby."
)
aguild.guild_settings.accountability_lobby.value = channel
slot.lobby = channel
except discord.HTTPException:
# Event log failure and skip session
aguild.guild_settings.event_log.log(
"Failed to create the scheduled session lobby text channel.\n"
"Please set the lobby channel manually with `config`."
)
await slot.cancel()
continue
# Event log creation
aguild.guild_settings.event_log.log(
"Automatically created a scheduled session lobby channel {}.".format(channel.mention)
)
results = await slot.open()
if results is None:
# Couldn't open the channel for some reason.
# Should already have been logged in `open`.
# Skip this session
await slot.cancel()
continue
elif slot.data:
to_update.append((results[0], results[1], slot.data.slotid))
# Time slot should now be open and ready to start
aguild.upcoming_slot = slot
else:
# Unload guild from cache
AccountabilityGuild.cache.pop(aguild.guildid, None)
# Update slot data
if to_update:
accountability_rooms.update_many(
*to_update,
set_keys=('channelid', 'messageid'),
where_keys=('slotid',)
)
async def turnover():
"""
Switchover from the current accountability rooms to the next ones.
To be executed as close as possible to the hour.
"""
now = utc_now()
# Open event lock so we don't read voice channel movement
async with voice_ignore_lock:
# Update session data for completed sessions
last_slots = [
aguild.current_slot for aguild in AccountabilityGuild.cache.values()
if aguild.current_slot is not None
]
to_update = [
(mem.data.duration + int((now - mem.data.last_joined_at).total_seconds()), None, mem.slotid, mem.userid)
for slot in last_slots for mem in slot.members.values()
if mem.data and mem.data.last_joined_at
]
if to_update:
accountability_members.update_many(
*to_update,
set_keys=('duration', 'last_joined_at'),
where_keys=('slotid', 'userid'),
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
)
# Close all completed rooms, update data
await asyncio.gather(*(slot.close() for slot in last_slots), return_exceptions=True)
update_slots = [slot.data.slotid for slot in last_slots if slot.data]
if update_slots:
accountability_rooms.update_where(
{'closed_at': utc_now()},
slotid=update_slots
)
# Rotate guild sessions
[aguild.advance() for aguild in AccountabilityGuild.cache.values()]
# TODO: (FUTURE) with high volume, we might want to start the sessions before moving the members.
# We could break up the session starting?
# ---------- Start next session ----------
current_slots = [
aguild.current_slot for aguild in AccountabilityGuild.cache.values()
if aguild.current_slot is not None
]
slotmap = {slot.data.slotid: slot for slot in current_slots if slot.data}
# Reload the slot members in case they cancelled from another shard
member_data = accountability_members.fetch_rows_where(
slotid=list(slotmap.keys())
) if slotmap else []
slot_memberids = {slotid: [] for slotid in slotmap}
for row in member_data:
slot_memberids[row.slotid].append(row.userid)
reload_tasks = (
slot._reload_members(memberids=slot_memberids[slotid])
for slotid, slot in slotmap.items()
)
await asyncio.gather(
*reload_tasks,
return_exceptions=True
)
# Move members of the next session over to the session channel
movement_tasks = (
mem.member.edit(
voice_channel=slot.channel,
reason="Moving to scheduled session."
)
for slot in current_slots
for mem in slot.members.values()
if mem.data and mem.member and mem.member.voice and mem.member.voice.channel != slot.channel
)
# We return exceptions here to ignore any permission issues that occur with moving members.
# It's also possible (likely) that members will move while we are moving other members
# Returning the exceptions ensures that they are explicitly ignored
await asyncio.gather(
*movement_tasks,
return_exceptions=True
)
# Update session data of all members in new channels
member_session_data = [
(0, slot.start_time, mem.slotid, mem.userid)
for slot in current_slots
for mem in slot.members.values()
if mem.data and mem.member and mem.member.voice and mem.member.voice.channel == slot.channel
]
if member_session_data:
accountability_members.update_many(
*member_session_data,
set_keys=('duration', 'last_joined_at'),
where_keys=('slotid', 'userid'),
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
)
# Start all the current rooms
await asyncio.gather(
*(slot.start() for slot in current_slots),
return_exceptions=True
)
@client.add_after_event('voice_state_update')
async def room_watchdog(client, member, before, after):
"""
Update session data when a member joins or leaves an accountability room.
Ignores events that occur while `voice_ignore_lock` is held.
"""
if not voice_ignore_lock.locked() and before.channel != after.channel:
aguild = AccountabilityGuild.cache.get(member.guild.id)
if aguild and aguild.current_slot and aguild.current_slot.channel:
slot = aguild.current_slot
if member.id in slot.members:
if after.channel and after.channel.id != slot.channel.id:
# Summon them back!
asyncio.create_task(member.edit(voice_channel=slot.channel))
slot_member = slot.members[member.id]
data = slot_member.data
if before.channel and before.channel.id == slot.channel.id:
# Left accountability room
with data.batch_update():
data.duration += int((utc_now() - data.last_joined_at).total_seconds())
data.last_joined_at = None
await slot.update_status()
elif after.channel and after.channel.id == slot.channel.id:
# Joined accountability room
with data.batch_update():
data.last_joined_at = utc_now()
await slot.update_status()
async def _accountability_loop():
"""
Runloop in charge of executing the room update tasks at the correct times.
"""
# Wait until ready
while not client.is_ready():
await asyncio.sleep(0.1)
# Calculate starting next_time
# Assume the resume logic has taken care of all events/tasks before current_time
now = utc_now()
if now.minute < 55:
next_time = now.replace(minute=55, second=0, microsecond=0)
else:
next_time = now.replace(minute=0, second=0, microsecond=0) + datetime.timedelta(hours=1)
# Executor loop
while True:
# TODO: (FUTURE) handle cases where we actually execute much late than expected
await sleep_until(next_time)
if next_time.minute == 55:
next_time = next_time + datetime.timedelta(minutes=5)
# Open next sessions
try:
async with room_lock:
await open_next(next_time)
except Exception:
# Unknown exception. Catch it so the loop doesn't die.
client.log(
"Error while opening new scheduled sessions! "
"Exception traceback follows.\n{}".format(
traceback.format_exc()
),
context="ACCOUNTABILITY_LOOP",
level=logging.ERROR
)
elif next_time.minute == 0:
# Start new sessions
try:
async with room_lock:
await turnover()
except Exception:
# Unknown exception. Catch it so the loop doesn't die.
client.log(
"Error while starting scheduled sessions! "
"Exception traceback follows.\n{}".format(
traceback.format_exc()
),
context="ACCOUNTABILITY_LOOP",
level=logging.ERROR
)
next_time = next_time + datetime.timedelta(minutes=55)
async def _accountability_system_resume():
"""
Logic for starting the accountability system from cold.
Essentially, session and state resume logic.
"""
now = utc_now()
# Fetch the open room data, only takes into account currently running sessions.
# May include sessions that were never opened, or opened but never started
# Does not include sessions that were opened that start on the next hour
open_room_data = accountability_rooms.fetch_rows_where(
closed_at=NULL,
start_at=LEQ(now),
guildid=THIS_SHARD,
_extra="ORDER BY start_at ASC"
)
if open_room_data:
# Extract member data of these rows
member_data = accountability_members.fetch_rows_where(
slotid=[row.slotid for row in open_room_data]
)
slot_members = collections.defaultdict(list)
for row in member_data:
slot_members[row.slotid].append(row)
# Filter these into expired rooms and current rooms
expired_room_data = []
current_room_data = []
for row in open_room_data:
if row.start_at + datetime.timedelta(hours=1) < now:
expired_room_data.append(row)
else:
current_room_data.append(row)
session_updates = []
# TODO URGENT: Batch room updates here
# Expire the expired rooms
for row in expired_room_data:
if row.channelid is None or row.messageid is None:
# TODO refunds here
# If the rooms were never opened, close them and skip
row.closed_at = now
else:
# If the rooms were opened and maybe started, make optimistic guesses on session data and close.
session_end = row.start_at + datetime.timedelta(hours=1)
session_updates.extend(
(mow.duration + int((session_end - mow.last_joined_at).total_seconds()),
None, mow.slotid, mow.userid)
for mow in slot_members[row.slotid] if mow.last_joined_at
)
if client.get_guild(row.guildid):
slot = TimeSlot(client.get_guild(row.guildid), row.start_at, data=row).load(
memberids=[mow.userid for mow in slot_members[row.slotid]]
)
try:
await slot.close()
except discord.HTTPException:
pass
row.closed_at = now
# Load the in-progress room data
if current_room_data:
async with voice_ignore_lock:
current_hour = now.replace(minute=0, second=0, microsecond=0)
await open_next(current_hour)
[aguild.advance() for aguild in AccountabilityGuild.cache.values()]
current_slots = [
aguild.current_slot
for aguild in AccountabilityGuild.cache.values()
if aguild.current_slot
]
session_updates.extend(
(mem.data.duration + int((now - mem.data.last_joined_at).total_seconds()),
None, mem.slotid, mem.userid)
for slot in current_slots
for mem in slot.members.values()
if mem.data.last_joined_at and mem.member not in slot.channel.members
)
session_updates.extend(
(mem.data.duration,
now, mem.slotid, mem.userid)
for slot in current_slots
for mem in slot.members.values()
if not mem.data.last_joined_at and mem.member in slot.channel.members
)
if session_updates:
accountability_members.update_many(
*session_updates,
set_keys=('duration', 'last_joined_at'),
where_keys=('slotid', 'userid'),
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
)
await asyncio.gather(
*(aguild.current_slot.start()
for aguild in AccountabilityGuild.cache.values() if aguild.current_slot)
)
else:
if session_updates:
accountability_members.update_many(
*session_updates,
set_keys=('duration', 'last_joined_at'),
where_keys=('slotid', 'userid'),
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
)
# If we are in the last five minutes of the hour, open new rooms.
# Note that these may already have been opened, or they may not have been.
if now.minute >= 55:
await open_next(
now.replace(minute=0, second=0, microsecond=0) + datetime.timedelta(hours=1)
)
@module.launch_task
async def launch_accountability_system(client):
"""
Launcher for the accountability system.
Resumes saved sessions, and starts the accountability loop.
"""
# Load the AccountabilityGuild cache
guilds = tables.guild_config.fetch_rows_where(
accountability_category=NOTNULL,
guildid=THIS_SHARD
)
# Further filter out any guilds that we aren't in
[AccountabilityGuild(guild.guildid) for guild in guilds if client.get_guild(guild.guildid)]
await _accountability_system_resume()
asyncio.create_task(_accountability_loop())
async def unload_accountability(client):
"""
Save the current sessions and cancel the runloop in preparation for client shutdown.
"""
...
@client.add_after_event('member_join')
async def restore_accountability(client, member):
"""
Restore accountability channel permissions when a member rejoins the server, if applicable.
"""
aguild = AccountabilityGuild.cache.get(member.guild.id, None)
if aguild:
if aguild.current_slot and member.id in aguild.current_slot.members:
# Restore member permission for current slot
slot = aguild.current_slot
if slot.channel:
asyncio.create_task(discord_shield(
slot.channel.set_permissions(
member,
overwrite=slot._member_overwrite
)
))
if aguild.upcoming_slot and member.id in aguild.upcoming_slot.members:
slot = aguild.upcoming_slot
if slot.channel:
asyncio.create_task(discord_shield(
slot.channel.set_permissions(
member,
overwrite=slot._member_overwrite
)
))

View File

@@ -0,0 +1,7 @@
from .module import module
from . import guild_config
from . import statreset
from . import new_members
from . import reaction_roles
from . import economy

View File

@@ -0,0 +1,3 @@
from ..module import module
from . import set_coins

View File

@@ -0,0 +1,104 @@
import discord
import datetime
from wards import guild_admin
from settings import GuildSettings
from core import Lion
from ..module import module
POSTGRES_INT_MAX = 2147483647
@module.cmd(
"set_coins",
group="Guild Admin",
desc="Set coins on a member."
)
@guild_admin()
async def cmd_set(ctx):
"""
Usage``:
{prefix}set_coins <user mention> <amount>
Description:
Sets the given number of coins on the mentioned user.
If a number greater than 0 is mentioned, will add coins.
If a number less than 0 is mentioned, will remove coins.
Note: LionCoins on a member cannot be negative.
Example:
{prefix}set_coins {ctx.author.mention} 100
{prefix}set_coins {ctx.author.mention} -100
"""
# Extract target and amount
# Handle a slightly more flexible input than stated
splits = ctx.args.split()
digits = [isNumber(split) for split in splits[:2]]
mentions = ctx.msg.mentions
if len(splits) < 2 or not any(digits) or not (all(digits) or mentions):
return await _send_usage(ctx)
if all(digits):
# Both are digits, hopefully one is a member id, and one is an amount.
target, amount = ctx.guild.get_member(int(splits[0])), int(splits[1])
if not target:
amount, target = int(splits[0]), ctx.guild.get_member(int(splits[1]))
if not target:
return await _send_usage(ctx)
elif digits[0]:
amount, target = int(splits[0]), mentions[0]
elif digits[1]:
target, amount = mentions[0], int(splits[1])
# Fetch the associated lion
target_lion = Lion.fetch(ctx.guild.id, target.id)
# Check sanity conditions
if target == ctx.client.user:
return await ctx.embed_reply("Thanks, but Ari looks after all my needs!")
if target.bot:
return await ctx.embed_reply("We are still waiting for {} to open an account.".format(target.mention))
# Finally, send the amount and the ack message
# Postgres `coins` column is `integer`, sanity check postgres int limits - which are smalled than python int range
target_coins_to_set = target_lion.coins + amount
if target_coins_to_set >= 0 and target_coins_to_set <= POSTGRES_INT_MAX:
target_lion.addCoins(amount)
elif target_coins_to_set < 0:
target_coins_to_set = -target_lion.coins # Coins cannot go -ve, cap to 0
target_lion.addCoins(target_coins_to_set)
target_coins_to_set = 0
else:
return await ctx.embed_reply("Member coins cannot be more than {}".format(POSTGRES_INT_MAX))
embed = discord.Embed(
title="Funds Set",
description="You have set LionCoins on {} to **{}**!".format(target.mention,target_coins_to_set),
colour=discord.Colour.orange(),
timestamp=datetime.datetime.utcnow()
).set_footer(text=str(ctx.author), icon_url=ctx.author.avatar_url)
await ctx.reply(embed=embed, reference=ctx.msg)
GuildSettings(ctx.guild.id).event_log.log(
"{} set {}'s LionCoins to`{}`.".format(
ctx.author.mention,
target.mention,
target_coins_to_set
),
title="Funds Set"
)
def isNumber(var):
try:
return isinstance(int(var), int)
except:
return False
async def _send_usage(ctx):
return await ctx.error_reply(
"**Usage:** `{prefix}set_coins <mention> <amount>`\n"
"**Example:**\n"
" {prefix}set_coins {ctx.author.mention} 100\n"
" {prefix}set_coins {ctx.author.mention} -100".format(
prefix=ctx.best_prefix,
ctx=ctx
)
)

View File

@@ -0,0 +1,163 @@
import difflib
import discord
from cmdClient.lib import SafeCancellation
from wards import guild_admin, guild_moderator
from settings import UserInputError, GuildSettings
from utils.lib import prop_tabulate
import utils.ctx_addons # noqa
from .module import module
# Pages of configuration categories to display
cat_pages = {
'Administration': ('Meta', 'Guild Roles', 'New Members'),
'Moderation': ('Moderation', 'Video Channels'),
'Productivity': ('Study Tracking', 'TODO List', 'Workout'),
'Study Rooms': ('Rented Rooms', 'Scheduled Sessions'),
}
# Descriptions of each configuration category
descriptions = {
}
@module.cmd("config",
desc="View and modify the server settings.",
flags=('add', 'remove'),
group="Guild Configuration")
@guild_moderator()
async def cmd_config(ctx, flags):
"""
Usage``:
{prefix}config
{prefix}config info
{prefix}config <setting>
{prefix}config <setting> <value>
Description:
Display the server configuration panel, and view/modify the server settings.
Use `{prefix}config` to see the settings with their current values, or `{prefix}config info` to \
show brief descriptions instead.
Use `{prefix}config <setting>` (e.g. `{prefix}config event_log`) to view a more detailed description for each setting, \
including the possible values.
Finally, use `{prefix}config <setting> <value>` to set the setting to the given value.
To unset a setting, or set it to the default, use `{prefix}config <setting> None`.
Additional usage for settings which accept a list of values:
`{prefix}config <setting> <value1>, <value2>, ...`
`{prefix}config <setting> --add <value1>, <value2>, ...`
`{prefix}config <setting> --remove <value1>, <value2>, ...`
Note that the first form *overwrites* the setting completely,\
while the second two will only *add* and *remove* values, respectively.
Examples``:
{prefix}config event_log
{prefix}config event_log {ctx.ch.name}
{prefix}config autoroles Member, Level 0, Level 10
{prefix}config autoroles --remove Level 10
"""
# Cache and map some info for faster access
setting_displaynames = {setting.display_name.lower(): setting for setting in GuildSettings.settings.values()}
if not ctx.args or ctx.args.lower() in ('info', 'help'):
# Fill the setting cats
cats = {}
for setting in GuildSettings.settings.values():
cat = cats.get(setting.category, [])
cat.append(setting)
cats[setting.category] = cat
# Format the cats
sections = {}
for catname, cat in cats.items():
catprops = {
setting.display_name: setting.get(ctx.guild.id).summary if not ctx.args else setting.desc
for setting in cat
}
# TODO: Add cat description here
sections[catname] = prop_tabulate(*zip(*catprops.items()))
# Put the cats on the correct pages
pages = []
for page_name, cat_names in cat_pages.items():
page = {
cat_name: sections[cat_name] for cat_name in cat_names if cat_name in sections
}
if page:
embed = discord.Embed(
colour=discord.Colour.orange(),
title=page_name,
description=(
"View brief setting descriptions with `{prefix}config info`.\n"
"Use e.g. `{prefix}config event_log` to see more details.\n"
"Modify a setting with e.g. `{prefix}config event_log {ctx.ch.name}`.\n"
"See the [Online Tutorial]({tutorial}) for a complete setup guide.".format(
prefix=ctx.best_prefix,
ctx=ctx,
tutorial="https://discord.studylions.com/tutorial"
)
)
)
for name, value in page.items():
embed.add_field(name=name, value=value, inline=False)
pages.append(embed)
if len(pages) > 1:
[
embed.set_footer(text="Page {} of {}".format(i+1, len(pages)))
for i, embed in enumerate(pages)
]
await ctx.pager(pages)
elif pages:
await ctx.reply(embed=pages[0])
else:
await ctx.reply("No configuration options set up yet!")
else:
# Some args were given
parts = ctx.args.split(maxsplit=1)
name = parts[0]
setting = setting_displaynames.get(name.lower(), None)
if setting is None:
matches = difflib.get_close_matches(name, setting_displaynames.keys(), n=2)
match = "`{}`".format('` or `'.join(matches)) if matches else None
return await ctx.error_reply(
"Couldn't find a setting called `{}`!\n"
"{}"
"Use `{}config info` to see all the server settings.".format(
name,
"Maybe you meant {}?\n".format(match) if match else "",
ctx.best_prefix
)
)
if len(parts) == 1 and not ctx.msg.attachments:
# config <setting>
# View config embed for provided setting
await setting.get(ctx.guild.id).widget(ctx, flags=flags)
else:
# config <setting> <value>
# Ignoring the write ward currently and just enforcing admin
# Check the write ward
# if not await setting.write_ward.run(ctx):
# raise SafeCancellation(setting.write_ward.msg)
if not await guild_admin.run(ctx):
raise SafeCancellation("You need to be a server admin to modify settings!")
# Attempt to set config setting
try:
parsed = await setting.parse(ctx.guild.id, ctx, parts[1] if len(parts) > 1 else '')
parsed.write(add_only=flags['add'], remove_only=flags['remove'])
except UserInputError as e:
await ctx.reply(embed=discord.Embed(
description="{} {}".format('', e.msg),
colour=discord.Colour.red()
))
else:
await ctx.reply(embed=discord.Embed(
description="{} {}".format('', setting.get(ctx.guild.id).success_response),
colour=discord.Colour.green()
))

View File

@@ -0,0 +1,4 @@
from LionModule import LionModule
module = LionModule("Guild_Admin")

View File

@@ -0,0 +1,3 @@
from . import settings
from . import greetings
from . import roles

View File

@@ -0,0 +1,6 @@
from data import Table, RowTable
autoroles = Table('autoroles')
bot_autoroles = Table('bot_autoroles')
past_member_roles = Table('past_member_roles')

View File

@@ -0,0 +1,29 @@
import discord
from LionContext import LionContext as Context
from meta import client
from .settings import greeting_message, greeting_channel, returning_message
@client.add_after_event('member_join')
async def send_greetings(client, member):
guild = member.guild
returning = bool(client.data.lions.fetch((guild.id, member.id)))
# Handle greeting message
channel = greeting_channel.get(guild.id).value
if channel is not None:
if channel == greeting_channel.DMCHANNEL:
channel = member
ctx = Context(client, guild=guild, author=member)
if returning:
args = returning_message.get(guild.id).args(ctx)
else:
args = greeting_message.get(guild.id).args(ctx)
try:
await channel.send(**args)
except discord.HTTPException:
pass

View File

@@ -0,0 +1,115 @@
import asyncio
import discord
from collections import defaultdict
from meta import client
from core import Lion
from settings import GuildSettings
from .settings import autoroles, bot_autoroles, role_persistence
from .data import past_member_roles
# Locks to avoid storing the roles while adding them
# The locking is cautious, leaving data unchanged upon collision
locks = defaultdict(asyncio.Lock)
@client.add_after_event('member_join')
async def join_role_tracker(client, member):
"""
Add autoroles or saved roles as needed.
"""
guild = member.guild
if not guild.me.guild_permissions.manage_roles:
# We can't manage the roles here, don't try to give/restore the member roles
return
async with locks[(guild.id, member.id)]:
if role_persistence.get(guild.id).value and client.data.lions.fetch((guild.id, member.id)):
# Lookup stored roles
role_rows = past_member_roles.select_where(
guildid=guild.id,
userid=member.id
)
# Identify roles from roleids
roles = (guild.get_role(row['roleid']) for row in role_rows)
# Remove non-existent roles
roles = (role for role in roles if role is not None)
# Remove roles the client can't add
roles = [role for role in roles if role < guild.me.top_role]
if roles:
try:
await member.add_roles(
*roles,
reason="Restoring saved roles.",
)
except discord.HTTPException:
# This shouldn't ususally happen, but there are valid cases where it can
# E.g. the user left while we were restoring their roles
pass
# Event log!
GuildSettings(guild.id).event_log.log(
"Restored the following roles for returning member {}:\n{}".format(
member.mention,
', '.join(role.mention for role in roles)
),
title="Saved roles restored"
)
else:
# Add autoroles
roles = bot_autoroles.get(guild.id).value if member.bot else autoroles.get(guild.id).value
# Remove roles the client can't add
roles = [role for role in roles if role < guild.me.top_role]
if roles:
try:
await member.add_roles(
*roles,
reason="Adding autoroles.",
)
except discord.HTTPException:
# This shouldn't ususally happen, but there are valid cases where it can
# E.g. the user left while we were adding autoroles
pass
# Event log!
GuildSettings(guild.id).event_log.log(
"Gave {} the guild autoroles:\n{}".format(
member.mention,
', '.join(role.mention for role in roles)
),
titles="Autoroles added"
)
@client.add_after_event('member_remove')
async def left_role_tracker(client, member):
"""
Delete and re-store member roles when they leave the server.
"""
if (member.guild.id, member.id) in locks and locks[(member.guild.id, member.id)].locked():
# Currently processing a join event
# Which means the member left while we were adding their roles
# Cautiously return, not modifying the saved role data
return
# Delete existing member roles for this user
# NOTE: Not concurrency-safe
past_member_roles.delete_where(
guildid=member.guild.id,
userid=member.id,
)
if role_persistence.get(member.guild.id).value:
# Make sure the user has an associated lion, so we can detect when they rejoin
Lion.fetch(member.guild.id, member.id)
# Then insert the current member roles
values = [
(member.guild.id, member.id, role.id)
for role in member.roles
if not role.is_bot_managed() and not role.is_integration() and not role.is_default()
]
if values:
past_member_roles.insert_many(
*values,
insert_keys=('guildid', 'userid', 'roleid')
)

View File

@@ -0,0 +1,303 @@
import datetime
import discord
import settings
from settings import GuildSettings, GuildSetting
import settings.setting_types as stypes
from wards import guild_admin
from .data import autoroles, bot_autoroles
@GuildSettings.attach_setting
class greeting_channel(stypes.Channel, GuildSetting):
"""
Setting describing the destination of the greeting message.
Extended to support the following special values, with input and output supported.
Data `None` corresponds to `Off`.
Data `1` corresponds to `DM`.
"""
DMCHANNEL = object()
category = "New Members"
attr_name = 'greeting_channel'
_data_column = 'greeting_channel'
display_name = "welcome_channel"
desc = "Channel to send the welcome message in"
long_desc = (
"Channel to post the `welcome_message` in when a new user joins the server. "
"Accepts `DM` to indicate the welcome should be sent via direct message."
)
_accepts = (
"Text Channel name/id/mention, or `DM`, or `None` to disable."
)
_chan_type = discord.ChannelType.text
@classmethod
def _data_to_value(cls, id, data, **kwargs):
if data is None:
return None
elif data == 1:
return cls.DMCHANNEL
else:
return super()._data_to_value(id, data, **kwargs)
@classmethod
def _data_from_value(cls, id, value, **kwargs):
if value is None:
return None
elif value == cls.DMCHANNEL:
return 1
else:
return super()._data_from_value(id, value, **kwargs)
@classmethod
async def _parse_userstr(cls, ctx, id, userstr, **kwargs):
lower = userstr.lower()
if lower in ('0', 'none', 'off'):
return None
elif lower == 'dm':
return 1
else:
return await super()._parse_userstr(ctx, id, userstr, **kwargs)
@classmethod
def _format_data(cls, id, data, **kwargs):
if data is None:
return "Off"
elif data == 1:
return "DM"
else:
return "<#{}>".format(data)
@property
def success_response(self):
value = self.value
if not value:
return "Welcome messages are disabled."
elif value == self.DMCHANNEL:
return "Welcome messages will be sent via direct message."
else:
return "Welcome messages will be posted in {}".format(self.formatted)
@GuildSettings.attach_setting
class greeting_message(stypes.Message, GuildSetting):
category = "New Members"
attr_name = 'greeting_message'
_data_column = 'greeting_message'
display_name = 'welcome_message'
desc = "Welcome message sent to welcome new members."
long_desc = (
"Message to send to the configured `welcome_channel` when a member joins the server for the first time."
)
_default = r"""
{
"embed": {
"title": "Welcome!",
"thumbnail": {"url": "{guild_icon}"},
"description": "Hi {mention}!\nWelcome to **{guild_name}**! You are the **{member_count}**th member.\nThere are currently **{studying_count}** people studying.\nGood luck and stay productive!",
"color": 15695665
}
}
"""
_substitution_desc = {
'{mention}': "Mention the new member.",
'{user_name}': "Username of the new member.",
'{user_avatar}': "Avatar of the new member.",
'{guild_name}': "Name of this server.",
'{guild_icon}': "Server icon url.",
'{member_count}': "Number of members in the server.",
'{studying_count}': "Number of current voice channel members.",
}
def substitution_keys(self, ctx, **kwargs):
return {
'{mention}': ctx.author.mention,
'{user_name}': ctx.author.name,
'{user_avatar}': str(ctx.author.avatar_url),
'{guild_name}': ctx.guild.name,
'{guild_icon}': str(ctx.guild.icon_url),
'{member_count}': str(len(ctx.guild.members)),
'{studying_count}': str(len([member for ch in ctx.guild.voice_channels for member in ch.members]))
}
@property
def success_response(self):
return "The welcome message has been set!"
@GuildSettings.attach_setting
class returning_message(stypes.Message, GuildSetting):
category = "New Members"
attr_name = 'returning_message'
_data_column = 'returning_message'
display_name = 'returning_message'
desc = "Welcome message sent to returning members."
long_desc = (
"Message to send to the configured `welcome_channel` when a member returns to the server."
)
_default = r"""
{
"embed": {
"title": "Welcome Back {user_name}!",
"thumbnail": {"url": "{guild_icon}"},
"description": "Welcome back to **{guild_name}**!\nYou last studied with us <t:{last_time}:R>.\nThere are currently **{studying_count}** people studying.\nGood luck and stay productive!",
"color": 15695665
}
}
"""
_substitution_desc = {
'{mention}': "Mention the returning member.",
'{user_name}': "Username of the member.",
'{user_avatar}': "Avatar of the member.",
'{guild_name}': "Name of this server.",
'{guild_icon}': "Server icon url.",
'{member_count}': "Number of members in the server.",
'{studying_count}': "Number of current voice channel members.",
'{last_time}': "Unix timestamp of the last time the member studied.",
}
def substitution_keys(self, ctx, **kwargs):
return {
'{mention}': ctx.author.mention,
'{user_name}': ctx.author.name,
'{user_avatar}': str(ctx.author.avatar_url),
'{guild_name}': ctx.guild.name,
'{guild_icon}': str(ctx.guild.icon_url),
'{member_count}': str(len(ctx.guild.members)),
'{studying_count}': str(len([member for ch in ctx.guild.voice_channels for member in ch.members])),
'{last_time}': int(ctx.alion.data._timestamp.replace(tzinfo=datetime.timezone.utc).timestamp()),
}
@property
def success_response(self):
return "The returning message has been set!"
@GuildSettings.attach_setting
class starting_funds(stypes.Integer, GuildSetting):
category = "New Members"
attr_name = 'starting_funds'
_data_column = 'starting_funds'
display_name = 'starting_funds'
desc = "Coins given when a user first joins."
long_desc = (
"Members will be given this number of coins the first time they join the server."
)
_default = 1000
@property
def success_response(self):
return "Members will be given `{}` coins when they first join the server.".format(self.formatted)
@GuildSettings.attach_setting
class autoroles(stypes.RoleList, settings.ListData, settings.Setting):
category = "New Members"
write_ward = guild_admin
attr_name = 'autoroles'
_table_interface = autoroles
_id_column = 'guildid'
_data_column = 'roleid'
display_name = "autoroles"
desc = "Roles to give automatically to new members."
_force_unique = True
long_desc = (
"These roles will be given automatically to users when they join the server. "
"If `role_persistence` is enabled, the roles will only be given the first time a user joins the server."
)
# Flat cache, no need to expire
_cache = {}
@property
def success_response(self):
if self.value:
return "New members will be given the following roles:\n{}".format(self.formatted)
else:
return "New members will not automatically be given any roles."
@GuildSettings.attach_setting
class bot_autoroles(stypes.RoleList, settings.ListData, settings.Setting):
category = "New Members"
write_ward = guild_admin
attr_name = 'bot_autoroles'
_table_interface = bot_autoroles
_id_column = 'guildid'
_data_column = 'roleid'
display_name = "bot_autoroles"
desc = "Roles to give automatically to new bots."
_force_unique = True
long_desc = (
"These roles will be given automatically to bots when they join the server. "
"If `role_persistence` is enabled, the roles will only be given the first time a bot joins the server."
)
# Flat cache, no need to expire
_cache = {}
@property
def success_response(self):
if self.value:
return "New bots will be given the following roles:\n{}".format(self.formatted)
else:
return "New bots will not automatically be given any roles."
@GuildSettings.attach_setting
class role_persistence(stypes.Boolean, GuildSetting):
category = "New Members"
attr_name = "role_persistence"
_data_column = 'persist_roles'
display_name = "role_persistence"
desc = "Whether to remember member roles when they leave the server."
_outputs = {True: "Enabled", False: "Disabled"}
_default = True
long_desc = (
"When enabled, restores member roles when they rejoin the server.\n"
"This enables profile roles and purchased roles, such as field of study and colour roles, "
"as well as moderation roles, "
"such as the studyban and mute roles, to persist even when a member leaves and rejoins.\n"
"Note: Members who leave while this is disabled will not have their roles restored."
)
@property
def success_response(self):
if self.value:
return "Roles will now be restored when a member rejoins."
else:
return "Member roles will no longer be saved or restored."

View File

@@ -0,0 +1,6 @@
from .module import module
from . import data
from . import settings
from . import tracker
from . import command

View File

@@ -0,0 +1,943 @@
import asyncio
import discord
from discord import PartialEmoji
from cmdClient.lib import ResponseTimedOut, UserCancelled
from wards import guild_admin
from settings import UserInputError
from utils.lib import tick, cross
from .module import module
from .tracker import ReactionRoleMessage
from .data import reaction_role_reactions, reaction_role_messages
from . import settings
example_emoji = "🧮"
example_str = "🧮 mathematics, 🫀 biology, 💻 computer science, 🖼️ design, 🩺 medicine"
def _parse_messageref(ctx):
"""
Parse a message reference from the context message and return it.
Removes the parsed string from `ctx.args` if applicable.
Supports the following reference types, in precedence order:
- A Discord message reply reference.
- A message link.
- A message id.
Returns: (channelid, messageid)
`messageid` will be `None` if a valid reference was not found.
`channelid` will be `None` if the message was provided by pure id.
"""
target_id = None
target_chid = None
if ctx.msg.reference:
# True message reference extract message and return
target_id = ctx.msg.reference.message_id
target_chid = ctx.msg.reference.channel_id
elif ctx.args:
# Parse the first word of the message arguments
splits = ctx.args.split(maxsplit=1)
maybe_target = splits[0]
# Expect a message id or message link
if maybe_target.isdigit():
# Assume it is a message id
target_id = int(maybe_target)
elif '/' in maybe_target:
# Assume it is a link
# Split out the channelid and messageid, if possible
link_splits = maybe_target.rsplit('/', maxsplit=2)
if len(link_splits) > 1 and link_splits[-1].isdigit() and link_splits[-2].isdigit():
target_id = int(link_splits[-1])
target_chid = int(link_splits[-2])
# If we found a target id, truncate the arguments
if target_id is not None:
if len(splits) > 1:
ctx.args = splits[1].strip()
else:
ctx.args = ""
else:
# Last-ditch attempt, see if the argument could be a stored reaction
maybe_emoji = maybe_target.strip(',')
guild_message_rows = reaction_role_messages.fetch_rows_where(guildid=ctx.guild.id)
messages = [ReactionRoleMessage.fetch(row.messageid) for row in guild_message_rows]
emojis = {reaction.emoji: message for message in messages for reaction in message.reactions}
emoji_name_map = {emoji.name.lower(): emoji for emoji in emojis}
emoji_id_map = {emoji.id: emoji for emoji in emojis if emoji.id}
result = _parse_emoji(maybe_emoji, emoji_name_map, emoji_id_map)
if result and result in emojis:
message = emojis[result]
target_id = message.messageid
target_chid = message.data.channelid
# Return the message reference
return (target_chid, target_id)
def _parse_emoji(emoji_str, name_map, id_map):
"""
Extract a PartialEmoji from a user provided emoji string, given the accepted raw names and ids.
"""
emoji = None
if len(emoji_str) < 10 and all(ord(char) >= 256 for char in emoji_str):
# The string is pure unicode, we assume built in emoji
emoji = PartialEmoji(name=emoji_str)
elif emoji_str.lower() in name_map:
emoji = name_map[emoji_str.lower()]
elif emoji_str.isdigit() and int(emoji_str) in id_map:
emoji = id_map[int(emoji_str)]
else:
# Attempt to parse as custom emoji
# Accept custom emoji provided in the full form
emoji_split = emoji_str.strip('<>:').split(':')
if len(emoji_split) in (2, 3) and emoji_split[-1].isdigit():
emoji_id = int(emoji_split[-1])
emoji_name = emoji_split[-2]
emoji_animated = emoji_split[0] == 'a'
emoji = PartialEmoji(
name=emoji_name,
id=emoji_id,
animated=emoji_animated
)
return emoji
async def reaction_ask(ctx, question, timeout=120, timeout_msg=None, cancel_msg=None):
"""
Asks the author the provided question in an embed, and provides check/cross reactions for answering.
"""
embed = discord.Embed(
colour=discord.Colour.orange(),
description=question
)
out_msg = await ctx.reply(embed=embed)
# Wait for a tick/cross
asyncio.create_task(out_msg.add_reaction(tick))
asyncio.create_task(out_msg.add_reaction(cross))
def check(reaction, user):
result = True
result = result and reaction.message == out_msg
result = result and user == ctx.author
result = result and (reaction.emoji == tick or reaction.emoji == cross)
return result
try:
reaction, _ = await ctx.client.wait_for(
'reaction_add',
check=check,
timeout=120
)
except asyncio.TimeoutError:
try:
await out_msg.edit(
embed=discord.Embed(
colour=discord.Colour.red(),
description=timeout_msg or "Prompt timed out."
)
)
except discord.HTTPException:
pass
raise ResponseTimedOut from None
if reaction.emoji == cross:
try:
await out_msg.edit(
embed=discord.Embed(
colour=discord.Colour.red(),
description=cancel_msg or "Cancelled."
)
)
except discord.HTTPException:
pass
raise UserCancelled from None
try:
await out_msg.delete()
except discord.HTTPException:
pass
return True
_message_setting_flags = {
'removable': settings.removable,
'maximum': settings.maximum,
'required_role': settings.required_role,
'log': settings.log,
'refunds': settings.refunds,
'default_price': settings.default_price,
}
_reaction_setting_flags = {
'price': settings.price,
'duration': settings.duration
}
@module.cmd(
"reactionroles",
group="Guild Configuration",
desc="Create and configure reaction role messages.",
aliases=('rroles',),
flags=(
'delete', 'remove==',
'enable', 'disable',
'required_role==', 'removable=', 'maximum=', 'refunds=', 'log=', 'default_price=',
'price=', 'duration=='
)
)
@guild_admin()
async def cmd_reactionroles(ctx, flags):
"""
Usage``:
{prefix}rroles
{prefix}rroles [enable|disable|delete] msglink
{prefix}rroles msglink [emoji1 role1, emoji2 role2, ...]
{prefix}rroles msglink --remove emoji1, emoji2, ...
{prefix}rroles msglink --message_setting [value]
{prefix}rroles msglink emoji --reaction_setting [value]
Description:
Create and configure "reaction roles", i.e. roles obtainable by \
clicking reactions on a particular message.
`msglink` is the link or message id of the message with reactions.
`emoji` should be given as the emoji itself, or the name or id.
`role` may be given by name, mention, or id.
Getting started:
First choose the message you want to add reaction roles to, \
and copy the link or message id for that message. \
Then run the command `{prefix}rroles link`, replacing `link` with the copied link, \
and follow the prompts.
For faster setup, use `{prefix}rroles link emoji1 role1, emoji2 role2` instead.
Editing reaction roles:
Remove roles with `{prefix}rroles link --remove emoji1, emoji2, ...`
Add/edit roles with `{prefix}rroles link emoji1 role1, emoji2 role2, ...`
Examples``:
{prefix}rroles {ctx.msg.id} 🧮 mathematics, 🫀 biology, 🩺 medicine
{prefix}rroles disable {ctx.msg.id}
PAGEBREAK:
Page 2
Advanced configuration:
Type `{prefix}rroles link` again to view the advanced setting window, \
and use `{prefix}rroles link --setting value` to modify the settings. \
See below for descriptions of each message setting.
For example to disable event logging, run `{prefix}rroles link --log off`.
For per-reaction settings, instead use `{prefix}rroles link emoji --setting value`.
*(!) Replace `setting` with one of the settings below!*
Message Settings::
maximum: Maximum number of roles obtainable from this message.
log: Whether to log reaction role usage into the event log.
removable: Whether the reactions roles can be remove by unreacting.
refunds: Whether to refund the role price when removing the role.
default_price: The default price of each role on this message.
required_role: The role required to use these reactions roles.
Reaction Settings::
price: The price of this reaction role. (May be negative for a reward.)
tduration: How long this role will last after being selected or bought.
Configuration Examples``:
{prefix}rroles {ctx.msg.id} --maximum 5
{prefix}rroles {ctx.msg.id} --default_price 20
{prefix}rroles {ctx.msg.id} --required_role None
{prefix}rroles {ctx.msg.id} 🧮 --price 1024
{prefix}rroles {ctx.msg.id} 🧮 --duration 7 days
"""
if not ctx.args:
# No target message provided, list the current reaction messages
# Or give a brief guide if there are no current reaction messages
guild_message_rows = reaction_role_messages.fetch_rows_where(guildid=ctx.guild.id)
if guild_message_rows:
# List messages
# First get the list of reaction role messages in the guild
messages = [ReactionRoleMessage.fetch(row.messageid) for row in guild_message_rows]
# Sort them by channelid and messageid
messages.sort(key=lambda m: (m.data.channelid, m.messageid))
# Build the message description strings
message_strings = []
for message in messages:
header = (
"`{}` in <#{}> ([Click to jump]({})){}".format(
message.messageid,
message.data.channelid,
message.message_link,
" (disabled)" if not message.enabled else ""
)
)
role_strings = [
"{} <@&{}>".format(reaction.emoji, reaction.data.roleid)
for reaction in message.reactions
]
role_string = '\n'.join(role_strings) or "No reaction roles!"
message_strings.append("{}\n{}".format(header, role_string))
pages = []
page = []
page_len = 0
page_chars = 0
i = 0
while i < len(message_strings):
message_string = message_strings[i]
chars = len(message_string)
lines = len(message_string.splitlines())
if (page and lines + page_len > 20) or (chars + page_chars > 2000):
pages.append('\n\n'.join(page))
page = []
page_len = 0
page_chars = 0
else:
page.append(message_string)
page_len += lines
page_chars += chars
i += 1
if page:
pages.append('\n\n'.join(page))
page_count = len(pages)
title = "Reaction Roles in {}".format(ctx.guild.name)
embeds = [
discord.Embed(
colour=discord.Colour.orange(),
description=page,
title=title
)
for page in pages
]
if page_count > 1:
[embed.set_footer(text="Page {} of {}".format(i + 1, page_count)) for i, embed in enumerate(embeds)]
await ctx.pager(embeds)
else:
# Send a setup guide
embed = discord.Embed(
title="No Reaction Roles set up!",
description=(
"To setup reaction roles, first copy the link or message id of the message you want to "
"add the roles to. Then run `{prefix}rroles link`, replacing `link` with the link you copied, "
"and follow the prompts.\n"
"See `{prefix}help rroles` for more information.".format(prefix=ctx.best_prefix)
),
colour=discord.Colour.orange()
)
await ctx.reply(embed=embed)
return
# Extract first word, look for a subcommand
splits = ctx.args.split(maxsplit=1)
subcmd = splits[0].lower()
if subcmd in ('enable', 'disable', 'delete'):
# Truncate arguments and extract target
if len(splits) > 1:
ctx.args = splits[1]
target_chid, target_id = _parse_messageref(ctx)
else:
target_chid = None
target_id = None
ctx.args = ''
# Handle subcommand special cases
if subcmd == 'enable':
if ctx.args and not target_id:
await ctx.error_reply(
"Couldn't find the message to enable!\n"
"**Usage:** `{}rroles enable [message link or id]`.".format(ctx.best_prefix)
)
elif not target_id:
# Confirm enabling of all reaction messages
await reaction_ask(
ctx,
"Are you sure you want to enable all reaction role messages in this server?",
timeout_msg="Prompt timed out, no reaction roles enabled.",
cancel_msg="User cancelled, no reaction roles enabled."
)
reaction_role_messages.update_where(
{'enabled': True},
guildid=ctx.guild.id
)
await ctx.embed_reply(
"All reaction role messages have been enabled.",
colour=discord.Colour.green(),
)
else:
# Fetch the target
target = ReactionRoleMessage.fetch(target_id)
if target is None:
await ctx.error_reply(
"This message doesn't have any reaction roles!\n"
"Run the command again without `enable` to assign reaction roles."
)
else:
# We have a valid target
if target.enabled:
await ctx.error_reply(
"This message is already enabled!"
)
else:
target.enabled = True
await ctx.embed_reply(
"The message has been enabled!"
)
elif subcmd == 'disable':
if ctx.args and not target_id:
await ctx.error_reply(
"Couldn't find the message to disable!\n"
"**Usage:** `{}rroles disable [message link or id]`.".format(ctx.best_prefix)
)
elif not target_id:
# Confirm disabling of all reaction messages
await reaction_ask(
ctx,
"Are you sure you want to disable all reaction role messages in this server?",
timeout_msg="Prompt timed out, no reaction roles disabled.",
cancel_msg="User cancelled, no reaction roles disabled."
)
reaction_role_messages.update_where(
{'enabled': False},
guildid=ctx.guild.id
)
await ctx.embed_reply(
"All reaction role messages have been disabled.",
colour=discord.Colour.green(),
)
else:
# Fetch the target
target = ReactionRoleMessage.fetch(target_id)
if target is None:
await ctx.error_reply(
"This message doesn't have any reaction roles! Nothing to disable."
)
else:
# We have a valid target
if not target.enabled:
await ctx.error_reply(
"This message is already disabled!"
)
else:
target.enabled = False
await ctx.embed_reply(
"The message has been disabled!"
)
elif subcmd == 'delete':
if ctx.args and not target_id:
await ctx.error_reply(
"Couldn't find the message to remove!\n"
"**Usage:** `{}rroles remove [message link or id]`.".format(ctx.best_prefix)
)
elif not target_id:
# Confirm disabling of all reaction messages
await reaction_ask(
ctx,
"Are you sure you want to remove all reaction role messages in this server?",
timeout_msg="Prompt timed out, no messages removed.",
cancel_msg="User cancelled, no messages removed."
)
reaction_role_messages.delete_where(
guildid=ctx.guild.id
)
await ctx.embed_reply(
"All reaction role messages have been removed.",
colour=discord.Colour.green(),
)
else:
# Fetch the target
target = ReactionRoleMessage.fetch(target_id)
if target is None:
await ctx.error_reply(
"This message doesn't have any reaction roles! Nothing to remove."
)
else:
# We have a valid target
target.delete()
await ctx.embed_reply(
"The message has been removed and is no longer a reaction role message."
)
return
else:
# Just extract target
target_chid, target_id = _parse_messageref(ctx)
# Handle target parsing issue
if target_id is None:
return await ctx.error_reply(
"Couldn't parse `{}` as a message id or message link!\n"
"See `{}help rroles` for detailed usage information.".format(ctx.args.split()[0], ctx.best_prefix)
)
# Get the associated ReactionRoleMessage, if it exists
target = ReactionRoleMessage.fetch(target_id)
# Get the target message
if target:
message = await target.fetch_message()
if not message:
# TODO: Consider offering some sort of `move` option here.
await ctx.error_reply(
"This reaction role message no longer exists!\n"
"Use `{}rroles delete {}` to remove it from the list.".format(ctx.best_prefix, target.messageid)
)
else:
message = None
if target_chid:
channel = ctx.guild.get_channel(target_chid)
if not channel:
await ctx.error_reply(
"The provided channel no longer exists!"
)
elif not isinstance(channel, discord.TextChannel):
await ctx.error_reply(
"The provided channel is not a text channel!"
)
else:
message = await channel.fetch_message(target_id)
if not message:
await ctx.error_reply(
"Couldn't find the specified message in {}!".format(channel.mention)
)
else:
out_msg = await ctx.embed_reply("Searching for `{}`".format(target_id))
message = await ctx.find_message(target_id)
try:
await out_msg.delete()
except discord.HTTPException:
pass
if not message:
await ctx.error_reply(
"Couldn't find the message `{}`!".format(target_id)
)
if not message:
return
# Handle the `remove` flag specially
# In particular, all other flags are ignored
if flags['remove']:
if not target:
await ctx.error_reply(
"The specified message has no reaction roles! Nothing to remove."
)
else:
# Parse emojis and remove from target
target_emojis = {reaction.emoji: reaction for reaction in target.reactions}
emoji_name_map = {emoji.name.lower(): emoji for emoji in target_emojis}
emoji_id_map = {emoji.id: emoji for emoji in target_emojis}
items = [item.strip() for item in flags['remove'].split(',')]
to_remove = [] # List of reactions to remove
for emoji_str in items:
emoji = _parse_emoji(emoji_str, emoji_name_map, emoji_id_map)
if emoji is None:
return await ctx.error_reply(
"Couldn't parse `{}` as an emoji! No reactions were removed.".format(emoji_str)
)
if emoji not in target_emojis:
return await ctx.error_reply(
"{} is not a reaction role for this message!".format(emoji)
)
to_remove.append(target_emojis[emoji])
# Delete reactions from data
description = '\n'.join("{} <@&{}>".format(reaction.emoji, reaction.data.roleid) for reaction in to_remove)
reaction_role_reactions.delete_where(reactionid=[reaction.reactionid for reaction in to_remove])
target.refresh()
# Ack
embed = discord.Embed(
colour=discord.Colour.green(),
title="Reaction Roles deactivated",
description=description
)
await ctx.reply(embed=embed)
return
# Any remaining arguments should be emoji specifications with optional role
# Parse these now
given_emojis = {} # Map PartialEmoji -> Optional[Role]
existing_emojis = set() # Set of existing reaction emoji identifiers
if ctx.args:
# First build the list of custom emojis we can accept by name
# We do this by reverse precedence, so the highest priority emojis are added last
custom_emojis = []
custom_emojis.extend(ctx.guild.emojis) # Custom emojis in the guild
if target:
custom_emojis.extend([r.emoji for r in target.reactions]) # Configured reaction roles on the target
custom_emojis.extend([r.emoji for r in message.reactions if r.custom_emoji]) # Actual reactions on the message
# Filter out the built in emojis and those without a name
custom_emojis = (emoji for emoji in custom_emojis if emoji.name and emoji.id)
# Build the maps to lookup provided custom emojis
emoji_name_map = {emoji.name.lower(): emoji for emoji in custom_emojis}
emoji_id_map = {emoji.id: emoji for emoji in custom_emojis}
# Now parse the provided emojis
# Assume that all-unicode strings are built-in emojis
# We can't assume much else unless we have a list of such emojis
splits = (split.strip() for line in ctx.args.splitlines() for split in line.split(',') if split)
splits = (split.split(maxsplit=1) for split in splits if split)
arg_emoji_strings = {
split[0]: split[1] if len(split) > 1 else None
for split in splits
} # emoji_str -> Optional[role_str]
arg_emoji_map = {}
for emoji_str, role_str in arg_emoji_strings.items():
emoji = _parse_emoji(emoji_str, emoji_name_map, emoji_id_map)
if emoji is None:
return await ctx.error_reply(
"Couldn't parse `{}` as an emoji!".format(emoji_str)
)
else:
arg_emoji_map[emoji] = role_str
# Final pass extracts roles
# If any new emojis were provided, their roles should be specified, we enforce this during role parsing
# First collect the existing emoji strings
if target:
for reaction in target.reactions:
emoji_id = reaction.emoji.name if reaction.emoji.id is None else reaction.emoji.id
existing_emojis.add(emoji_id)
# Now parse and assign the roles, building the final map
for emoji, role_str in arg_emoji_map.items():
emoji_id = emoji.name if emoji.id is None else emoji.id
role = None
if role_str:
role = await ctx.find_role(role_str, create=True, interactive=True, allow_notfound=False)
elif emoji_id not in existing_emojis:
return await ctx.error_reply(
"New emoji {} was given without an associated role!".format(emoji)
)
given_emojis[emoji] = role
# Next manage target creation or emoji editing, if required
if target is None:
# Reaction message creation wizard
# Confirm that they want to create a new reaction role message.
await reaction_ask(
ctx,
question="Do you want to set up new reaction roles for [this message]({})?".format(
message.jump_url
),
timeout_msg="Prompt timed out, no reaction roles created.",
cancel_msg="Reaction Role creation cancelled."
)
# Continue with creation
# Obtain emojis if not already provided
if not given_emojis:
# Prompt for the initial emojis
embed = discord.Embed(
colour=discord.Colour.orange(),
title="What reaction roles would you like to add?",
description=(
"Please now type the reaction roles you would like to add "
"in the form `emoji role`, where `role` is given by partial name or id. For example:"
"```{}```".format(example_str)
)
)
out_msg = await ctx.reply(embed=embed)
# Wait for a response
def check(msg):
return msg.author == ctx.author and msg.channel == ctx.ch and msg.content
try:
reply = await ctx.client.wait_for('message', check=check, timeout=300)
except asyncio.TimeoutError:
try:
await out_msg.edit(
embed=discord.Embed(
colour=discord.Colour.red(),
description="Prompt timed out, no reaction roles created."
)
)
except discord.HTTPException:
pass
return
rolestrs = reply.content
try:
await reply.delete()
except discord.HTTPException:
pass
# Attempt to parse the emojis
# First build the list of custom emojis we can accept by name
custom_emojis = []
custom_emojis.extend(ctx.guild.emojis) # Custom emojis in the guild
custom_emojis.extend(
r.emoji for r in message.reactions if r.custom_emoji
) # Actual reactions on the message
# Filter out the built in emojis and those without a name
custom_emojis = (emoji for emoji in custom_emojis if emoji.name and emoji.id)
# Build the maps to lookup provided custom emojis
emoji_name_map = {emoji.name.lower(): emoji for emoji in custom_emojis}
emoji_id_map = {emoji.id: emoji for emoji in custom_emojis}
# Now parse the provided emojis
# Assume that all-unicode strings are built-in emojis
# We can't assume much else unless we have a list of such emojis
splits = (split.strip() for line in rolestrs.splitlines() for split in line.split(',') if split)
splits = (split.split(maxsplit=1) for split in splits if split)
arg_emoji_strings = {
split[0]: split[1] if len(split) > 1 else None
for split in splits
} # emoji_str -> Optional[role_str]
# Check all the emojis have roles associated
for emoji_str, role_str in arg_emoji_strings.items():
if role_str is None:
return await ctx.error_reply(
"No role provided for `{}`! Reaction role creation cancelled.".format(emoji_str)
)
# Parse the provided roles and emojis
for emoji_str, role_str in arg_emoji_strings.items():
emoji = _parse_emoji(emoji_str, emoji_name_map, emoji_id_map)
if emoji is None:
return await ctx.error_reply(
"Couldn't parse `{}` as an emoji!".format(emoji_str)
)
else:
given_emojis[emoji] = await ctx.find_role(
role_str,
create=True,
interactive=True,
allow_notfound=False
)
if len(given_emojis) > 20:
return await ctx.error_reply("A maximum of 20 reactions are possible per message! Cancelling creation.")
# Create the ReactionRoleMessage
target = ReactionRoleMessage.create(
message.id,
message.guild.id,
message.channel.id
)
# Insert the reaction data directly
reaction_role_reactions.insert_many(
*((message.id, role.id, emoji.name, emoji.id, emoji.animated) for emoji, role in given_emojis.items()),
insert_keys=('messageid', 'roleid', 'emoji_name', 'emoji_id', 'emoji_animated')
)
# Refresh the message to pick up the new reactions
target.refresh()
# Add the reactions to the message, if possible
existing_reactions = set(
reaction.emoji if not reaction.custom_emoji else
(reaction.emoji.name if reaction.emoji.id is None else reaction.emoji.id)
for reaction in message.reactions
)
missing = [
reaction.emoji for reaction in target.reactions
if (reaction.emoji.name if reaction.emoji.id is None else reaction.emoji.id) not in existing_reactions
]
if not any(emoji.id not in set(cemoji.id for cemoji in ctx.guild.emojis) for emoji in missing if emoji.id):
# We can add the missing emojis
for emoji in missing:
try:
await message.add_reaction(emoji)
except discord.HTTPException:
break
else:
missing = []
# Ack the creation
ack_msg = "Created `{}` new reaction roles on [this message]({})!".format(
len(target.reactions),
target.message_link
)
if missing:
ack_msg += "\nPlease add the missing reactions to the message!"
await ctx.embed_reply(
ack_msg
)
elif given_emojis:
# Update the target reactions
# Create a map of the emojis that need to be added or updated
needs_update = {
emoji: role for emoji, role in given_emojis.items() if role
}
# Fetch the existing target emojis to split the roles into inserts and updates
target_emojis = {reaction.emoji: reaction for reaction in target.reactions}
# Handle the new roles
insert_targets = {
emoji: role for emoji, role in needs_update.items() if emoji not in target_emojis
}
if insert_targets:
if len(insert_targets) + len(target_emojis) > 20:
return await ctx.error_reply("Too many reactions! A maximum of 20 reactions are possible per message!")
reaction_role_reactions.insert_many(
*(
(message.id, role.id, emoji.name, emoji.id, emoji.animated)
for emoji, role in insert_targets.items()
),
insert_keys=('messageid', 'roleid', 'emoji_name', 'emoji_id', 'emoji_animated')
)
# Handle the updated roles
update_targets = {
target_emojis[emoji]: role for emoji, role in needs_update.items() if emoji in target_emojis
}
if update_targets:
reaction_role_reactions.update_many(
*((role.id, reaction.reactionid) for reaction, role in update_targets.items()),
set_keys=('roleid',),
where_keys=('reactionid',),
)
# Finally, refresh to load the new reactions
target.refresh()
# Now that the target is created/updated, all the provided emojis should be reactions
given_reactions = []
if given_emojis:
# Make a map of the existing reactions
existing_reactions = {
reaction.emoji.name if reaction.emoji.id is None else reaction.emoji.id: reaction
for reaction in target.reactions
}
given_reactions = [
existing_reactions[emoji.name if emoji.id is None else emoji.id]
for emoji in given_emojis
]
# Handle message setting updates
update_lines = [] # Setting update lines to display
update_columns = {} # Message data columns to update
for flag in _message_setting_flags:
if flags[flag]:
setting_class = _message_setting_flags[flag]
try:
setting = await setting_class.parse(target.messageid, ctx, flags[flag])
except UserInputError as e:
return await ctx.error_reply(
"{} {}\nNo settings were modified.".format(cross, e.msg),
title="Couldn't save settings!"
)
else:
update_lines.append(
"{} {}".format(tick, setting.success_response)
)
update_columns[setting._data_column] = setting.data
if update_columns:
# First write the data
reaction_role_messages.update_where(
update_columns,
messageid=target.messageid
)
# Then ack the setting update
if len(update_lines) > 1:
embed = discord.Embed(
colour=discord.Colour.green(),
title="Reaction Role message settings updated!",
description='\n'.join(update_lines)
)
else:
embed = discord.Embed(
colour=discord.Colour.green(),
description=update_lines[0]
)
await ctx.reply(embed=embed)
# Handle reaction setting updates
update_lines = [] # Setting update lines to display
update_columns = {} # Message data columns to update, for all given reactions
reactions = given_reactions or target.reactions
for flag in _reaction_setting_flags:
for reaction in reactions:
if flags[flag]:
setting_class = _reaction_setting_flags[flag]
try:
setting = await setting_class.parse(reaction.reactionid, ctx, flags[flag])
except UserInputError as e:
return await ctx.error_reply(
"{} {}\nNo reaction roles were modified.".format(cross, e.msg),
title="Couldn't save reaction role settings!",
)
else:
update_lines.append(
setting.success_response.format(reaction=reaction)
)
update_columns[setting._data_column] = setting.data
if update_columns:
# First write the data
reaction_role_reactions.update_where(
update_columns,
reactionid=[reaction.reactionid for reaction in reactions]
)
# Then ack the setting update
if len(update_lines) > 1:
blocks = ['\n'.join(update_lines[i:i+20]) for i in range(0, len(update_lines), 20)]
embeds = [
discord.Embed(
colour=discord.Colour.green(),
title="Reaction Role settings updated!",
description=block
) for block in blocks
]
await ctx.pager(embeds)
else:
embed = discord.Embed(
colour=discord.Colour.green(),
description=update_lines[0]
)
await ctx.reply(embed=embed)
# Show the reaction role message summary
# Build the reaction fields
reaction_fields = [] # List of tuples (name, value)
for reaction in target.reactions:
reaction_fields.append(
(
"{} {}".format(reaction.emoji.name, reaction.emoji if reaction.emoji.id else ''),
"<@&{}>\n{}".format(reaction.data.roleid, reaction.settings.tabulated())
)
)
# Build the final setting pages
description = (
"{settings_table}\n"
"To update a message setting: `{prefix}rroles messageid --setting value`\n"
"To update an emoji setting: `{prefix}rroles messageid emoji --setting value`\n"
"See examples and more usage information with `{prefix}help rroles`.\n"
"**(!) Replace the `setting` with one of the settings on this page.**\n"
).format(
prefix=ctx.best_prefix,
settings_table=target.settings.tabulated()
)
field_blocks = [reaction_fields[i:i+6] for i in range(0, len(reaction_fields), 6)]
page_count = len(field_blocks)
embeds = []
for i, block in enumerate(field_blocks):
title = "Reaction role settings for message id `{}`".format(target.messageid)
embed = discord.Embed(
title=title,
description=description
).set_author(
name="Click to jump to message",
url=target.message_link
)
for name, value in block:
embed.add_field(name=name, value=value)
if page_count > 1:
embed.set_footer(text="Page {} of {}".format(i+1, page_count))
embeds.append(embed)
# Finally, send the reaction role information
await ctx.pager(embeds)

View File

@@ -0,0 +1,22 @@
from data import Table, RowTable
reaction_role_messages = RowTable(
'reaction_role_messages',
('messageid', 'guildid', 'channelid',
'enabled',
'required_role', 'allow_deselction',
'max_obtainable', 'allow_refunds',
'event_log'),
'messageid'
)
reaction_role_reactions = RowTable(
'reaction_role_reactions',
('reactionid', 'messageid', 'roleid', 'emoji_name', 'emoji_id', 'emoji_animated', 'price', 'timeout'),
'reactionid'
)
reaction_role_expiring = Table('reaction_role_expiring')

View File

@@ -0,0 +1,172 @@
import logging
import traceback
import asyncio
import discord
from meta import client
from utils.lib import utc_now
from settings import GuildSettings
from .module import module
from .data import reaction_role_expiring
_expiring = {}
_wakeup_event = asyncio.Event()
# TODO: More efficient data structure for min optimisation, e.g. pre-sorted with bisection insert
# Public expiry interface
def schedule_expiry(guildid, userid, roleid, expiry, reactionid=None):
"""
Schedule expiry of the given role for the given member at the given time.
This will also cancel any existing expiry for this member, role pair.
"""
reaction_role_expiring.delete_where(
guildid=guildid,
userid=userid,
roleid=roleid,
)
reaction_role_expiring.insert(
guildid=guildid,
userid=userid,
roleid=roleid,
expiry=expiry,
reactionid=reactionid
)
key = (guildid, userid, roleid)
_expiring[key] = expiry.timestamp()
_wakeup_event.set()
def cancel_expiry(*key):
"""
Cancel expiry for the given member and role, if it exists.
"""
guildid, userid, roleid = key
reaction_role_expiring.delete_where(
guildid=guildid,
userid=userid,
roleid=roleid,
)
if _expiring.pop(key, None) is not None:
# Wakeup the expiry tracker for recalculation
_wakeup_event.set()
def _next():
"""
Calculate the next member, role pair to expire.
"""
if _expiring:
key, _ = min(_expiring.items(), key=lambda pair: pair[1])
return key
else:
return None
async def _expire(key):
"""
Execute reaction role expiry for the given member and role.
This removes the role and logs the removal if applicable.
If the user is no longer in the guild, it removes the role from the persistent roles instead.
"""
guildid, userid, roleid = key
guild = client.get_guild(guildid)
if guild:
role = guild.get_role(roleid)
if role:
member = guild.get_member(userid)
if member:
log = GuildSettings(guildid).event_log.log
if role in member.roles:
# Remove role from member, and log if applicable
try:
await member.remove_roles(
role,
atomic=True,
reason="Expiring temporary reaction role."
)
except discord.HTTPException:
log(
"Failed to remove expired reaction role {} from {}.".format(
role.mention,
member.mention
),
colour=discord.Colour.red(),
title="Could not remove expired Reaction Role!"
)
else:
log(
"Removing expired reaction role {} from {}.".format(
role.mention,
member.mention
),
title="Reaction Role expired!"
)
else:
# Remove role from stored persistent roles, if existent
client.data.past_member_roles.delete_where(
guildid=guildid,
userid=userid,
roleid=roleid
)
reaction_role_expiring.delete_where(
guildid=guildid,
userid=userid,
roleid=roleid
)
async def _expiry_tracker(client):
"""
Track and launch role expiry.
"""
while True:
try:
key = _next()
diff = _expiring[key] - utc_now().timestamp() if key else None
await asyncio.wait_for(_wakeup_event.wait(), timeout=diff)
except asyncio.TimeoutError:
# Timeout means next doesn't exist or is ready to expire
if key and key in _expiring and _expiring[key] <= utc_now().timestamp() + 1:
_expiring.pop(key)
asyncio.create_task(_expire(key))
except Exception:
# This should be impossible, but catch and log anyway
client.log(
"Exception occurred while tracking reaction role expiry. Exception traceback follows.\n{}".format(
traceback.format_exc()
),
context="REACTION_ROLE_EXPIRY",
level=logging.ERROR
)
else:
# Wakeup event means that we should recalculate next
_wakeup_event.clear()
@module.launch_task
async def launch_expiry_tracker(client):
"""
Launch the role expiry tracker.
"""
asyncio.create_task(_expiry_tracker(client))
client.log("Reaction role expiry tracker launched.", context="REACTION_ROLE_EXPIRY")
@module.init_task
def load_expiring_roles(client):
"""
Initialise the expiring reaction role map, and attach it to the client.
"""
rows = reaction_role_expiring.select_where()
_expiring.clear()
_expiring.update({(row['guildid'], row['userid'], row['roleid']): row['expiry'].timestamp() for row in rows})
client.objects['expiring_reaction_roles'] = _expiring
if _expiring:
client.log(
"Loaded {} expiring reaction roles.".format(len(_expiring)),
context="REACTION_ROLE_EXPIRY"
)

View File

@@ -0,0 +1,3 @@
from LionModule import LionModule
module = LionModule("Reaction_Roles")

View File

@@ -0,0 +1,257 @@
from utils.lib import DotDict
from wards import guild_admin
from settings import ObjectSettings, ColumnData, Setting
import settings.setting_types as setting_types
from .data import reaction_role_messages, reaction_role_reactions
class RoleMessageSettings(ObjectSettings):
settings = DotDict()
class RoleMessageSetting(ColumnData, Setting):
_table_interface = reaction_role_messages
_id_column = 'messageid'
_create_row = False
write_ward = guild_admin
@RoleMessageSettings.attach_setting
class required_role(setting_types.Role, RoleMessageSetting):
attr_name = 'required_role'
_data_column = 'required_role'
display_name = "required_role"
desc = "Role required to use the reaction roles."
long_desc = (
"Members will be required to have the specified role to use the reactions on this message."
)
@property
def success_response(self):
if self.value:
return "Members need {} to use these reaction roles.".format(self.formatted)
else:
return "All members can now use these reaction roles."
@classmethod
def _get_guildid(cls, id: int, **kwargs):
return reaction_role_messages.fetch(id).guildid
@RoleMessageSettings.attach_setting
class removable(setting_types.Boolean, RoleMessageSetting):
attr_name = 'removable'
_data_column = 'removable'
display_name = "removable"
desc = "Whether the role is removable by deselecting the reaction."
long_desc = (
"If enabled, the role will be removed when the reaction is deselected."
)
_default = True
@property
def success_response(self):
if self.value:
return "Members will be able to remove roles by unreacting."
else:
return "Members will not be able to remove the reaction roles."
@RoleMessageSettings.attach_setting
class maximum(setting_types.Integer, RoleMessageSetting):
attr_name = 'maximum'
_data_column = 'maximum'
display_name = "maximum"
desc = "The maximum number of roles a member can get from this message."
long_desc = (
"The maximum number of roles that a member can get from this message. "
"They will be notified by DM if they attempt to add more.\n"
"The `removable` setting should generally be enabled with this setting."
)
accepts = "An integer number of roles, or `None` to remove the maximum."
_min = 0
@classmethod
def _format_data(cls, id, data, **kwargs):
if data is None:
return "No maximum!"
else:
return "`{}`".format(data)
@property
def success_response(self):
if self.value:
return "Members can get a maximum of `{}` roles from this message.".format(self.value)
else:
return "Members can now get all the roles from this mesage."
@RoleMessageSettings.attach_setting
class refunds(setting_types.Boolean, RoleMessageSetting):
attr_name = 'refunds'
_data_column = 'refunds'
display_name = "refunds"
desc = "Whether a user will be refunded when they deselect a role."
long_desc = (
"Whether to give the user a refund when they deselect a role by reaction. "
"This has no effect if `removable` is not enabled, or if the role removed has no cost."
)
_default = True
@property
def success_response(self):
if self.value:
return "Members will get a refund when they remove a role."
else:
return "Members will not get a refund when they remove a role."
@RoleMessageSettings.attach_setting
class default_price(setting_types.Integer, RoleMessageSetting):
attr_name = 'default_price'
_data_column = 'default_price'
display_name = "default_price"
desc = "Default price of reaction roles on this message."
long_desc = (
"Reaction roles on this message will have this cost if they do not have an individual price set."
)
accepts = "An integer number of coins. Use `0` or `None` to make roles free by default."
_default = 0
@classmethod
def _format_data(cls, id, data, **kwargs):
if not data:
return "Free"
else:
return "`{}` coins".format(data)
@property
def success_response(self):
if self.value:
return "Reaction roles on this message will cost `{}` coins by default.".format(self.value)
else:
return "Reaction roles on this message will be free by default."
@RoleMessageSettings.attach_setting
class log(setting_types.Boolean, RoleMessageSetting):
attr_name = 'log'
_data_column = 'event_log'
display_name = "log"
desc = "Whether to log reaction role usage in the event log."
long_desc = (
"When enabled, roles added or removed with reactions will be logged in the configured event log."
)
_default = True
@property
def success_response(self):
if self.value:
return "Role updates will now be logged."
else:
return "Role updates will not be logged."
class ReactionSettings(ObjectSettings):
settings = DotDict()
class ReactionSetting(ColumnData, Setting):
_table_interface = reaction_role_reactions
_id_column = 'reactionid'
_create_row = False
write_ward = guild_admin
@ReactionSettings.attach_setting
class price(setting_types.Integer, ReactionSetting):
attr_name = 'price'
_data_column = 'price'
display_name = "price"
desc = "Price of this reaction role (may be negative)."
long_desc = (
"The number of coins that will be deducted from the user when this reaction is used.\n"
"The number may be negative, in order to give a reward when the member choses the reaction."
)
accepts = "An integer number of coins. Use `0` to make the role free, or `None` to use the message default."
_max = 2 ** 20
@property
def default(self):
"""
The default price is given by the ReactionMessage price setting.
"""
return default_price.get(self._table_interface.fetch(self.id).messageid).value
@classmethod
def _format_data(cls, id, data, **kwargs):
if not data:
return "Free"
else:
return "`{}` coins".format(data)
@property
def success_response(self):
if self.value is not None:
return "{{reaction.emoji}} {{reaction.role.mention}} now costs `{}` coins.".format(self.value)
else:
return "{reaction.emoji} {reaction.role.mention} is now free."
@ReactionSettings.attach_setting
class duration(setting_types.Duration, ReactionSetting):
attr_name = 'duration'
_data_column = 'timeout'
display_name = "duration"
desc = "How long this reaction role will last."
long_desc = (
"If set, the reaction role will be removed after the configured duration. "
"Note that this does not affect existing members with the role, or existing expiries."
)
_default_multiplier = 3600
_show_days = True
_min = 600
@classmethod
def _format_data(cls, id, data, **kwargs):
if data is None:
return "Permanent"
else:
return super()._format_data(id, data, **kwargs)
@property
def success_response(self):
if self.value is not None:
return "{{reaction.emoji}} {{reaction.role.mention}} will expire `{}` after selection.".format(
self.formatted
)
else:
return "{reaction.emoji} {reaction.role.mention} will not expire."

View File

@@ -0,0 +1,590 @@
import asyncio
from codecs import ignore_errors
import logging
import traceback
import datetime
from collections import defaultdict
from typing import List, Mapping, Optional
from cachetools import LFUCache
import discord
from discord import PartialEmoji
from meta import client
from core import Lion
from data import Row
from data.conditions import THIS_SHARD
from utils.lib import utc_now
from settings import GuildSettings
from ..module import module
from .data import reaction_role_messages, reaction_role_reactions
from .settings import RoleMessageSettings, ReactionSettings
from .expiry import schedule_expiry, cancel_expiry
class ReactionRoleReaction:
"""
Light data class representing a reaction role reaction.
"""
__slots__ = ('reactionid', '_emoji', '_message', '_role')
def __init__(self, reactionid, message=None, **kwargs):
self.reactionid = reactionid
self._message: ReactionRoleMessage = None
self._role = None
self._emoji = None
@classmethod
def create(cls, messageid, roleid, emoji: PartialEmoji, message=None, **kwargs) -> 'ReactionRoleReaction':
"""
Create a new ReactionRoleReaction with the provided attributes.
`emoji` sould be provided as a PartialEmoji.
`kwargs` are passed transparently to the `insert` method.
"""
row = reaction_role_reactions.create_row(
messageid=messageid,
roleid=roleid,
emoji_name=emoji.name,
emoji_id=emoji.id,
emoji_animated=emoji.animated,
**kwargs
)
return cls(row.reactionid, message=message)
@property
def emoji(self) -> PartialEmoji:
if self._emoji is None:
data = self.data
self._emoji = PartialEmoji(
name=data.emoji_name,
animated=data.emoji_animated,
id=data.emoji_id,
)
return self._emoji
@property
def data(self) -> Row:
return reaction_role_reactions.fetch(self.reactionid)
@property
def settings(self) -> ReactionSettings:
return ReactionSettings(self.reactionid)
@property
def reaction_message(self):
if self._message is None:
self._message = ReactionRoleMessage.fetch(self.data.messageid)
return self._message
@property
def role(self):
if self._role is None:
guild = self.reaction_message.guild
if guild:
self._role = guild.get_role(self.data.roleid)
return self._role
class ReactionRoleMessage:
"""
Light data class representing a reaction role message.
Primarily acts as an interface to the corresponding Settings.
"""
__slots__ = ('messageid', '_message')
# Full live messageid cache for this client. Should always be up to date.
_messages: Mapping[int, 'ReactionRoleMessage'] = {} # messageid -> associated Reaction message
# Reaction cache for the live messages. Least frequently used, will be fetched on demand.
_reactions: Mapping[int, List[ReactionRoleReaction]] = LFUCache(1000) # messageid -> List of Reactions
# User-keyed locks so we only handle one reaction per user at a time
_locks: Mapping[int, asyncio.Lock] = defaultdict(asyncio.Lock) # userid -> Lock
def __init__(self, messageid):
self.messageid = messageid
self._message = None
@classmethod
def fetch(cls, messageid) -> 'ReactionRoleMessage':
"""
Fetch the ReactionRoleMessage for the provided messageid.
Returns None if the messageid is not registered.
"""
# Since the cache is assumed to be always up to date, just pass to fetch-from-cache.
return cls._messages.get(messageid, None)
@classmethod
def create(cls, messageid, guildid, channelid, **kwargs) -> 'ReactionRoleMessage':
"""
Create a ReactionRoleMessage with the given `messageid`.
Other `kwargs` are passed transparently to `insert`.
"""
# Insert the data
reaction_role_messages.create_row(
messageid=messageid,
guildid=guildid,
channelid=channelid,
**kwargs
)
# Create the ReactionRoleMessage
rmsg = cls(messageid)
# Add to the global cache
cls._messages[messageid] = rmsg
# Return the constructed ReactionRoleMessage
return rmsg
def delete(self):
"""
Delete this ReactionRoleMessage.
"""
# Remove message from cache
self._messages.pop(self.messageid, None)
# Remove reactions from cache
reactionids = [reaction.reactionid for reaction in self.reactions]
[self._reactions.pop(reactionid, None) for reactionid in reactionids]
# Remove message from data
reaction_role_messages.delete_where(messageid=self.messageid)
@property
def data(self) -> Row:
"""
Data row associated with this Message.
Passes directly to the RowTable cache.
Should not generally be used directly, use the settings interface instead.
"""
return reaction_role_messages.fetch(self.messageid)
@property
def settings(self):
"""
RoleMessageSettings associated to this Message.
"""
return RoleMessageSettings(self.messageid)
def refresh(self):
"""
Refresh the reaction cache for this message.
Returns the generated `ReactionRoleReaction`s for convenience.
"""
# Fetch reactions and pre-populate reaction cache
rows = reaction_role_reactions.fetch_rows_where(messageid=self.messageid, _extra="ORDER BY reactionid ASC")
reactions = [ReactionRoleReaction(row.reactionid) for row in rows]
self._reactions[self.messageid] = reactions
return reactions
@property
def reactions(self) -> List[ReactionRoleReaction]:
"""
Returns the list of active reactions for this message, as `ReactionRoleReaction`s.
Lazily fetches the reactions from data if they have not been loaded.
"""
reactions = self._reactions.get(self.messageid, None)
if reactions is None:
reactions = self.refresh()
return reactions
@property
def enabled(self) -> bool:
"""
Whether this Message is enabled.
Passes directly to data for efficiency.
"""
return self.data.enabled
@enabled.setter
def enabled(self, value: bool):
self.data.enabled = value
# Discord properties
@property
def guild(self) -> discord.Guild:
return client.get_guild(self.data.guildid)
@property
def channel(self) -> discord.TextChannel:
return client.get_channel(self.data.channelid)
async def fetch_message(self) -> discord.Message:
if self._message:
return self._message
channel = self.channel
if channel:
try:
self._message = await channel.fetch_message(self.messageid)
return self._message
except discord.NotFound:
# The message no longer exists
# TODO: Cache and data cleanup? Or allow moving after death?
pass
@property
def message(self) -> Optional[discord.Message]:
return self._message
@property
def message_link(self) -> str:
"""
Jump link tho the reaction message.
"""
return 'https://discord.com/channels/{}/{}/{}'.format(
self.data.guildid,
self.data.channelid,
self.messageid
)
# Event handlers
async def process_raw_reaction_add(self, payload):
"""
Process a general reaction add payload.
"""
event_log = GuildSettings(self.guild.id).event_log
async with self._locks[payload.user_id]:
reaction = next((reaction for reaction in self.reactions if reaction.emoji == payload.emoji), None)
if reaction:
# User pressed a live reaction. Process!
member = payload.member
lion = Lion.fetch(member.guild.id, member.id)
role = reaction.role
if reaction.role and (role not in member.roles):
# Required role check, make sure the user has the required role, if set.
required_role = self.settings.required_role.value
if required_role and required_role not in member.roles:
# Silently remove their reaction
try:
message = await self.fetch_message()
await message.remove_reaction(
payload.emoji,
member
)
except discord.HTTPException:
pass
return
# Maximum check, check whether the user already has too many roles from this message.
maximum = self.settings.maximum.value
if maximum is not None:
# Fetch the number of applicable roles the user has
roleids = set(reaction.data.roleid for reaction in self.reactions)
member_roleids = set(role.id for role in member.roles)
if len(roleids.intersection(member_roleids)) >= maximum:
# Notify the user
embed = discord.Embed(
title="Maximum group roles reached!",
description=(
"Couldn't give you **{}**, "
"because you already have `{}` roles from this group!".format(
role.name,
maximum
)
)
)
# Silently try to notify the user
try:
await member.send(embed=embed)
except discord.HTTPException:
pass
# Silently remove the reaction
try:
message = await self.fetch_message()
await message.remove_reaction(
payload.emoji,
member
)
except discord.HTTPException:
pass
return
# Economy hook, check whether the user can pay for the role.
price = reaction.settings.price.value
if price and price > lion.coins:
# They can't pay!
# Build the can't pay embed
embed = discord.Embed(
title="Insufficient funds!",
description="Sorry, **{}** costs `{}` coins, but you only have `{}`.".format(
role.name,
price,
lion.coins
),
colour=discord.Colour.red()
).set_footer(
icon_url=self.guild.icon_url,
text=self.guild.name
).add_field(
name="Jump Back",
value="[Click here]({})".format(self.message_link)
)
# Try to send them the embed, ignore errors
try:
await member.send(
embed=embed
)
except discord.HTTPException:
pass
# Remove their reaction, ignore errors
try:
message = await self.fetch_message()
await message.remove_reaction(
payload.emoji,
member
)
except discord.HTTPException:
pass
return
# Add the role
try:
await member.add_roles(
role,
atomic=True,
reason="Adding reaction role."
)
except discord.Forbidden:
event_log.log(
"Insufficient permissions to give {} the [reaction role]({}) {}".format(
member.mention,
self.message_link,
role.mention,
),
title="Failed to add reaction role",
colour=discord.Colour.red()
)
except discord.HTTPException:
event_log.log(
"Something went wrong while adding the [reaction role]({}) "
"{} to {}.".format(
self.message_link,
role.mention,
member.mention
),
title="Failed to add reaction role",
colour=discord.Colour.red()
)
client.log(
"Unexpected HTTPException encountered while adding '{}' (rid:{}) to "
"user '{}' (uid:{}) in guild '{}' (gid:{}).\n{}".format(
role.name,
role.id,
member,
member.id,
member.guild.name,
member.guild.id,
traceback.format_exc()
),
context="REACTION_ROLE_ADD",
level=logging.WARNING
)
else:
# Charge the user and notify them, if the price is set
if price:
lion.addCoins(-price)
# Notify the user of their purchase
embed = discord.Embed(
title="Purchase successful!",
description="You have purchased **{}** for `{}` coins!".format(
role.name,
price
),
colour=discord.Colour.green()
).set_footer(
icon_url=self.guild.icon_url,
text=self.guild.name
).add_field(
name="Jump Back",
value="[Click Here]({})".format(self.message_link)
)
try:
await member.send(embed=embed)
except discord.HTTPException:
pass
# Schedule the expiry, if required
duration = reaction.settings.duration.value
if duration:
expiry = utc_now() + datetime.timedelta(seconds=duration)
schedule_expiry(self.guild.id, member.id, role.id, expiry, reaction.reactionid)
else:
expiry = None
# Log the role modification if required
if self.settings.log.value:
event_log.log(
"Added [reaction role]({}) {} "
"to {}{}.{}".format(
self.message_link,
role.mention,
member.mention,
" for `{}` coins".format(price) if price else '',
"\nThis role will expire at <t:{:.0f}>.".format(
expiry.timestamp()
) if expiry else ''
),
title="Reaction Role Added"
)
async def process_raw_reaction_remove(self, payload):
"""
Process a general reaction remove payload.
"""
if self.settings.removable.value:
event_log = GuildSettings(self.guild.id).event_log
async with self._locks[payload.user_id]:
reaction = next((reaction for reaction in self.reactions if reaction.emoji == payload.emoji), None)
if reaction:
# User removed a live reaction. Process!
member = self.guild.get_member(payload.user_id)
role = reaction.role
if member and not member.bot and role and (role in member.roles):
# Check whether they have the required role, if set
required_role = self.settings.required_role.value
if required_role and required_role not in member.roles:
# Ignore the reaction removal
return
try:
await member.remove_roles(
role,
atomic=True,
reason="Removing reaction role."
)
except discord.Forbidden:
event_log.log(
"Insufficient permissions to remove "
"the [reaction role]({}) {} from {}".format(
self.message_link,
role.mention,
member.mention,
),
title="Failed to remove reaction role",
colour=discord.Colour.red()
)
except discord.HTTPException:
event_log.log(
"Something went wrong while removing the [reaction role]({}) "
"{} from {}.".format(
self.message_link,
role.mention,
member.mention
),
title="Failed to remove reaction role",
colour=discord.Colour.red()
)
client.log(
"Unexpected HTTPException encountered while removing '{}' (rid:{}) from "
"user '{}' (uid:{}) in guild '{}' (gid:{}).\n{}".format(
role.name,
role.id,
member,
member.id,
member.guild.name,
member.guild.id,
traceback.format_exc()
),
context="REACTION_ROLE_RM",
level=logging.WARNING
)
else:
# Economy hook, handle refund if required
price = reaction.settings.price.value
refund = self.settings.refunds.value
if price and refund:
# Give the user the refund
lion = Lion.fetch(self.guild.id, member.id)
lion.addCoins(price)
# Notify the user
embed = discord.Embed(
title="Role sold",
description=(
"You sold the role **{}** for `{}` coins.".format(
role.name,
price
)
),
colour=discord.Colour.green()
).set_footer(
icon_url=self.guild.icon_url,
text=self.guild.name
).add_field(
name="Jump Back",
value="[Click Here]({})".format(self.message_link)
)
try:
await member.send(embed=embed)
except discord.HTTPException:
pass
# Log role removal if required
if self.settings.log.value:
event_log.log(
"Removed [reaction role]({}) {} "
"from {}.".format(
self.message_link,
role.mention,
member.mention
),
title="Reaction Role Removed"
)
# Cancel any existing expiry
cancel_expiry(self.guild.id, member.id, role.id)
# TODO: Make all the embeds a bit nicer, and maybe make a consistent interface for them
# TODO: Handle RawMessageDelete event
# TODO: Handle permission errors when fetching message in config
@client.add_after_event('raw_reaction_add')
async def reaction_role_add(client, payload):
reaction_message = ReactionRoleMessage.fetch(payload.message_id)
if payload.guild_id and payload.user_id != client.user.id and reaction_message and reaction_message.enabled:
try:
await reaction_message.process_raw_reaction_add(payload)
except Exception:
# Unknown exception, catch and log it.
client.log(
"Unhandled exception while handling reaction message payload: {}\n{}".format(
payload,
traceback.format_exc()
),
context="REACTION_ROLE_ADD",
level=logging.ERROR
)
@client.add_after_event('raw_reaction_remove')
async def reaction_role_remove(client, payload):
reaction_message = ReactionRoleMessage.fetch(payload.message_id)
if payload.guild_id and reaction_message and reaction_message.enabled:
try:
await reaction_message.process_raw_reaction_remove(payload)
except Exception:
# Unknown exception, catch and log it.
client.log(
"Unhandled exception while handling reaction message payload: {}\n{}".format(
payload,
traceback.format_exc()
),
context="REACTION_ROLE_RM",
level=logging.ERROR
)
@module.init_task
def load_reaction_roles(client):
"""
Load the ReactionRoleMessages.
"""
rows = reaction_role_messages.fetch_rows_where(guildid=THIS_SHARD)
ReactionRoleMessage._messages = {row.messageid: ReactionRoleMessage(row.messageid) for row in rows}

View File

@@ -0,0 +1,65 @@
from io import StringIO
import discord
from wards import guild_admin
from data import tables
from core import Lion
from .module import module
@module.cmd("studyreset",
desc="Perform a reset of the server's study statistics.",
group="Guild Admin")
@guild_admin()
async def cmd_statreset(ctx):
"""
Usage``:
{prefix}studyreset
Description:
Perform a complete reset of the server's study statistics.
That is, deletes the tracked time of all members and removes their study badges.
This may be used to set "seasons" of study.
Before the reset, I will send a csv file with the current member statistics.
**This is not reversible.**
"""
if not await ctx.ask("Are you sure you want to reset the study time and badges for all members? "
"**THIS IS NOT REVERSIBLE!**"):
return
# Build the data csv
rows = tables.lions.select_where(
select_columns=('userid', 'tracked_time', 'coins', 'workout_count', 'b.roleid AS badge_roleid'),
_extra=(
"LEFT JOIN study_badges b ON last_study_badgeid = b.badgeid "
"WHERE members.guildid={}"
).format(ctx.guild.id)
)
header = "userid, tracked_time, coins, workouts, rank_roleid\n"
csv_rows = [
','.join(str(data) for data in row)
for row in rows
]
with StringIO() as stats_file:
stats_file.write(header)
stats_file.write('\n'.join(csv_rows))
stats_file.seek(0)
out_file = discord.File(stats_file, filename="guild_{}_member_statistics.csv".format(ctx.guild.id))
await ctx.reply(file=out_file)
# Reset the statistics
tables.lions.update_where(
{'tracked_time': 0},
guildid=ctx.guild.id
)
Lion.sync()
await ctx.embed_reply(
"The member study times have been reset!\n"
"(It may take a while for the studybadges to update.)"
)

View File

@@ -0,0 +1,7 @@
# flake8: noqa
from .module import module
from . import help
from . import links
from . import nerd
from . import join_message

View File

@@ -0,0 +1,237 @@
import discord
from cmdClient.checks import is_owner
from utils.lib import prop_tabulate
from utils import interactive, ctx_addons # noqa
from wards import is_guild_admin
from .module import module
from .lib import guide_link
new_emoji = " 🆕"
new_commands = {'botconfig', 'sponsors'}
# Set the command groups to appear in the help
group_hints = {
'Pomodoro': "*Stay in sync with your friends using our timers!*",
'Productivity': "*Use these to help you stay focused and productive!*",
'Statistics': "*StudyLion leaderboards and study statistics.*",
'Economy': "*Buy, sell, and trade with your hard-earned coins!*",
'Personal Settings': "*Tell me about yourself!*",
'Guild Admin': "*Dangerous administration commands!*",
'Guild Configuration': "*Control how I behave in your server.*",
'Meta': "*Information about me!*",
'Support Us': "*Support the team and keep the project alive by using LionGems!*"
}
standard_group_order = (
('Pomodoro', 'Productivity', 'Support Us', 'Statistics', 'Economy', 'Personal Settings', 'Meta'),
)
mod_group_order = (
('Moderation', 'Meta'),
('Pomodoro', 'Productivity', 'Support Us', 'Statistics', 'Economy', 'Personal Settings')
)
admin_group_order = (
('Guild Admin', 'Guild Configuration', 'Moderation', 'Meta'),
('Pomodoro', 'Productivity', 'Support Us', 'Statistics', 'Economy', 'Personal Settings')
)
bot_admin_group_order = (
('Bot Admin', 'Guild Admin', 'Guild Configuration', 'Moderation', 'Meta'),
('Pomodoro', 'Productivity', 'Support Us', 'Statistics', 'Economy', 'Personal Settings')
)
# Help embed format
# TODO: Add config fields for this
title = "StudyLion Command List"
header = """
[StudyLion](https://bot.studylions.com/) is a fully featured study assistant \
that tracks your study time and offers productivity tools \
such as to-do lists, task reminders, private study rooms, group accountability sessions, and much much more.\n
Use `{{ctx.best_prefix}}help <command>` (e.g. `{{ctx.best_prefix}}help send`) to learn how to use each command, \
or [click here]({guide_link}) for a comprehensive tutorial.
""".format(guide_link=guide_link)
@module.cmd("help",
group="Meta",
desc="StudyLion command list.",
aliases=('man', 'ls', 'list'))
async def cmd_help(ctx):
"""
Usage``:
{prefix}help [cmdname]
Description:
When used with no arguments, displays a list of commands with brief descriptions.
Otherwise, shows documentation for the provided command.
Examples:
{prefix}help
{prefix}help top
{prefix}help timezone
"""
if ctx.arg_str:
# Attempt to fetch the command
command = ctx.client.cmd_names.get(ctx.arg_str.strip(), None)
if command is None:
return await ctx.error_reply(
("Command `{}` not found!\n"
"Write `{}help` to see a list of commands.").format(ctx.args, ctx.best_prefix)
)
smart_help = getattr(command, 'smart_help', None)
if smart_help is not None:
return await smart_help(ctx)
help_fields = command.long_help.copy()
help_map = {field_name: i for i, (field_name, _) in enumerate(help_fields)}
if not help_map:
return await ctx.reply("No documentation has been written for this command yet!")
field_pages = [[]]
page_fields = field_pages[0]
for name, pos in help_map.items():
if name.endswith("``"):
# Handle codeline help fields
page_fields.append((
name.strip("`"),
"`{}`".format('`\n`'.join(help_fields[pos][1].splitlines()))
))
elif name.endswith(":"):
# Handle property/value help fields
lines = help_fields[pos][1].splitlines()
names = []
values = []
for line in lines:
split = line.split(":", 1)
names.append(split[0] if len(split) > 1 else "")
values.append(split[-1])
page_fields.append((
name.strip(':'),
prop_tabulate(names, values)
))
elif name == "Related":
# Handle the related field
names = [cmd_name.strip() for cmd_name in help_fields[pos][1].split(',')]
names.sort(key=len)
values = [
(getattr(ctx.client.cmd_names.get(cmd_name, None), 'desc', '') or '').format(ctx=ctx)
for cmd_name in names
]
page_fields.append((
name,
prop_tabulate(names, values)
))
elif name == "PAGEBREAK":
page_fields = []
field_pages.append(page_fields)
else:
page_fields.append((name, help_fields[pos][1]))
# Build the aliases
aliases = getattr(command, 'aliases', [])
alias_str = "(Aliases `{}`.)".format("`, `".join(aliases)) if aliases else ""
# Build the embeds
pages = []
for i, page_fields in enumerate(field_pages):
embed = discord.Embed(
title="`{}` command documentation. {}".format(
command.name,
alias_str
),
colour=discord.Colour(0x9b59b6)
)
for fieldname, fieldvalue in page_fields:
embed.add_field(
name=fieldname,
value=fieldvalue.format(ctx=ctx, prefix=ctx.best_prefix),
inline=False
)
embed.set_footer(
text="{}\n[optional] and <required> denote optional and required arguments, respectively.".format(
"Page {} of {}".format(i + 1, len(field_pages)) if len(field_pages) > 1 else '',
)
)
pages.append(embed)
# Post the embed
await ctx.pager(pages)
else:
# Build the command groups
cmd_groups = {}
for command in ctx.client.cmds:
# Get the command group
group = getattr(command, 'group', "Misc")
cmd_group = cmd_groups.get(group, [])
if not cmd_group:
cmd_groups[group] = cmd_group
# Add the command name and description to the group
cmd_group.append(
(command.name, (getattr(command, 'desc', '') + (new_emoji if command.name in new_commands else '')))
)
# Add any required aliases
for alias, desc in getattr(command, 'help_aliases', {}).items():
cmd_group.append((alias, desc))
# Turn the command groups into strings
stringy_cmd_groups = {}
for group_name, cmd_group in cmd_groups.items():
cmd_group.sort(key=lambda tup: len(tup[0]))
if ctx.alias == 'ls':
stringy_cmd_groups[group_name] = ', '.join(
f"`{name}`" for name, _ in cmd_group
)
else:
stringy_cmd_groups[group_name] = prop_tabulate(*zip(*cmd_group))
# Now put everything into a bunch of embeds
if await is_owner.run(ctx):
group_order = bot_admin_group_order
elif ctx.guild:
if is_guild_admin(ctx.author):
group_order = admin_group_order
elif ctx.guild_settings.mod_role.value in ctx.author.roles:
group_order = mod_group_order
else:
group_order = standard_group_order
else:
group_order = admin_group_order
help_embeds = []
for page_groups in group_order:
embed = discord.Embed(
description=header.format(ctx=ctx),
colour=discord.Colour(0x9b59b6),
title=title
)
for group in page_groups:
group_hint = group_hints.get(group, '').format(ctx=ctx)
group_str = stringy_cmd_groups.get(group, None)
if group_str:
embed.add_field(
name=group,
value="{}\n{}".format(group_hint, group_str).format(ctx=ctx),
inline=False
)
help_embeds.append(embed)
# Add the page numbers
for i, embed in enumerate(help_embeds):
embed.set_footer(text="Page {}/{}".format(i+1, len(help_embeds)))
# Send the embeds
if help_embeds:
await ctx.pager(help_embeds)
else:
await ctx.reply(
embed=discord.Embed(description=header, colour=discord.Colour(0x9b59b6))
)

View File

@@ -0,0 +1,50 @@
import discord
from cmdClient import cmdClient
from meta import client, conf
from .lib import guide_link, animation_link
message = """
Thank you for inviting me to your community.
Get started by typing `{prefix}help` to see my commands, and `{prefix}config info` \
to read about my configuration options!
To learn how to configure me and use all of my features, \
make sure to [click here]({guide_link}) to read our full setup guide.
Remember, if you need any help configuring me, \
want to suggest a feature, report a bug and stay updated, \
make sure to join our main support and study server by [clicking here]({support_link}).
Best of luck with your studies!
""".format(
guide_link=guide_link,
support_link=conf.bot.get('support_link'),
prefix=client.prefix
)
@client.add_after_event('guild_join', priority=0)
async def post_join_message(client: cmdClient, guild: discord.Guild):
try:
await guild.me.edit(nick="Leo")
except discord.HTTPException:
pass
if (channel := guild.system_channel) and channel.permissions_for(guild.me).embed_links:
embed = discord.Embed(
description=message
)
embed.set_author(
name="Hello everyone! My name is Leo, the StudyLion!",
icon_url="https://cdn.discordapp.com/emojis/933610591459872868.webp"
)
embed.set_image(url=animation_link)
try:
await channel.send(embed=embed)
except discord.HTTPException:
# Something went wrong sending the hi message
# Not much we can do about this
pass

View File

@@ -0,0 +1,5 @@
guide_link = "https://discord.studylions.com/tutorial"
animation_link = (
"https://media.discordapp.net/attachments/879412267731542047/926837189814419486/ezgif.com-resize.gif"
)

View File

@@ -0,0 +1,57 @@
import discord
from meta import conf
from LionContext import LionContext as Context
from .module import module
from .lib import guide_link
@module.cmd(
"support",
group="Meta",
desc=f"Have a question? Join my [support server]({conf.bot.get('support_link')})"
)
async def cmd_support(ctx: Context):
"""
Usage``:
{prefix}support
Description:
Replies with an invite link to my support server.
"""
await ctx.reply(
f"Click here to join my support server: {conf.bot.get('support_link')}"
)
@module.cmd(
"invite",
group="Meta",
desc=f"[Invite me]({conf.bot.get('invite_link')}) to your server so I can help your members stay productive!"
)
async def cmd_invite(ctx: Context):
"""
Usage``:
{prefix}invite
Description:
Replies with my invite link so you can add me to your server.
"""
embed = discord.Embed(
colour=discord.Colour.orange(),
description=f"Click here to add me to your server: {conf.bot.get('invite_link')}"
)
embed.add_field(
name="Setup tips",
value=(
"Remember to check out `{prefix}help` for the full command list, "
"and `{prefix}config info` for the configuration options.\n"
"[Click here]({guide}) for our comprehensive setup tutorial, and if you still have questions you can "
"join our support server [here]({support}) to talk to our friendly support team!"
).format(
prefix=ctx.best_prefix,
support=conf.bot.get('support_link'),
guide=guide_link
)
)
await ctx.reply(embed=embed)

View File

@@ -0,0 +1,3 @@
from LionModule import LionModule
module = LionModule("Meta")

View File

@@ -0,0 +1,144 @@
import datetime
import asyncio
import discord
import psutil
import sys
import gc
from data import NOTNULL
from data.queries import select_where
from utils.lib import prop_tabulate, utc_now
from LionContext import LionContext as Context
from .module import module
process = psutil.Process()
process.cpu_percent()
@module.cmd(
"nerd",
group="Meta",
desc="Information and statistics about me!"
)
async def cmd_nerd(ctx: Context):
"""
Usage``:
{prefix}nerd
Description:
View nerdy information and statistics about me!
"""
# Create embed
embed = discord.Embed(
colour=discord.Colour.orange(),
title="Nerd Panel",
description=(
"Hi! I'm [StudyLion]({studylion}), a study management bot owned by "
"[Ari Horesh]({ari}) and developed by [Conatum#5317]({cona}), with [contributors]({github})."
).format(
studylion="http://studylions.com/",
ari="https://arihoresh.com/",
cona="https://github.com/Intery",
github="https://github.com/StudyLions/StudyLion"
)
)
# ----- Study stats -----
# Current studying statistics
current_students, current_channels, current_guilds= (
ctx.client.data.current_sessions.select_one_where(
select_columns=(
"COUNT(*) AS studying_count",
"COUNT(DISTINCT(channelid)) AS channel_count",
"COUNT(DISTINCT(guildid)) AS guild_count"
)
)
)
# Past studying statistics
past_sessions, past_students, past_duration, past_guilds = ctx.client.data.session_history.select_one_where(
select_columns=(
"COUNT(*) AS session_count",
"COUNT(DISTINCT(userid)) AS user_count",
"SUM(duration) / 3600 AS total_hours",
"COUNT(DISTINCT(guildid)) AS guild_count"
)
)
# Tasklist statistics
tasks = ctx.client.data.tasklist.select_one_where(
select_columns=(
'COUNT(*)'
)
)[0]
tasks_completed = ctx.client.data.tasklist.select_one_where(
completed_at=NOTNULL,
select_columns=(
'COUNT(*)'
)
)[0]
# Timers
timer_count, timer_guilds = ctx.client.data.timers.select_one_where(
select_columns=("COUNT(*)", "COUNT(DISTINCT(guildid))")
)
study_fields = {
"Currently": f"`{current_students}` people working in `{current_channels}` rooms of `{current_guilds}` guilds",
"Recorded": f"`{past_duration}` hours from `{past_students}` people across `{past_sessions}` sessions",
"Tasks": f"`{tasks_completed}` out of `{tasks}` tasks completed",
"Timers": f"`{timer_count}` timers running in `{timer_guilds}` communities"
}
study_table = prop_tabulate(*zip(*study_fields.items()))
# ----- Shard statistics -----
shard_number = ctx.client.shard_id
shard_count = ctx.client.shard_count
guilds = len(ctx.client.guilds)
member_count = sum(guild.member_count for guild in ctx.client.guilds)
commands = len(ctx.client.cmds)
aliases = len(ctx.client.cmd_names)
dpy_version = discord.__version__
py_version = sys.version.split()[0]
data_version, data_time, _ = select_where(
"VersionHistory",
_extra="ORDER BY time DESC LIMIT 1"
)[0]
data_timestamp = int(data_time.replace(tzinfo=datetime.timezone.utc).timestamp())
shard_fields = {
"Shard": f"`{shard_number}` of `{shard_count}`",
"Guilds": f"`{guilds}` servers with `{member_count}` members (on this shard)",
"Commands": f"`{commands}` commands with `{aliases}` keywords",
"Version": f"`v{data_version}`, last updated <t:{data_timestamp}:F>",
"Py version": f"`{py_version}` running discord.py `{dpy_version}`"
}
shard_table = prop_tabulate(*zip(*shard_fields.items()))
# ----- Execution statistics -----
running_commands = len(ctx.client.active_contexts)
tasks = len(asyncio.all_tasks())
objects = len(gc.get_objects())
cpu_percent = process.cpu_percent()
mem_percent = int(process.memory_percent())
uptime = int(utc_now().timestamp() - process.create_time())
execution_fields = {
"Running": f"`{running_commands}` commands",
"Waiting for": f"`{tasks}` tasks to complete",
"Objects": f"`{objects}` loaded in memory",
"Usage": f"`{cpu_percent}%` CPU, `{mem_percent}%` MEM",
"Uptime": f"`{uptime // (24 * 3600)}` days, `{uptime // 3600 % 24:02}:{uptime // 60 % 60:02}:{uptime % 60:02}`"
}
execution_table = prop_tabulate(*zip(*execution_fields.items()))
# ----- Combine and output -----
embed.add_field(name="Study Stats", value=study_table, inline=False)
embed.add_field(name=f"Shard Info", value=shard_table, inline=False)
embed.add_field(name=f"Process Stats", value=execution_table, inline=False)
await ctx.reply(embed=embed)

View File

@@ -0,0 +1,9 @@
from .module import module
from . import data
from . import admin
from . import tickets
from . import video
from . import commands

View File

@@ -0,0 +1,109 @@
import discord
from settings import GuildSettings, GuildSetting
from wards import guild_admin
import settings
from .data import studyban_durations
@GuildSettings.attach_setting
class mod_log(settings.Channel, GuildSetting):
category = "Moderation"
attr_name = 'mod_log'
_data_column = 'mod_log_channel'
display_name = "mod_log"
desc = "Moderation event logging channel."
long_desc = (
"Channel to post moderation tickets.\n"
"These are produced when a manual or automatic moderation action is performed on a member. "
"This channel acts as a more context rich moderation history source than the audit log."
)
_chan_type = discord.ChannelType.text
@property
def success_response(self):
if self.value:
return "Moderation tickets will be posted to {}.".format(self.formatted)
else:
return "The moderation log has been unset."
@GuildSettings.attach_setting
class studyban_role(settings.Role, GuildSetting):
category = "Moderation"
attr_name = 'studyban_role'
_data_column = 'studyban_role'
display_name = "studyban_role"
desc = "The role given to members to prevent them from using server study features."
long_desc = (
"This role is to be given to members to prevent them from using the server's study features.\n"
"Typically, this role should act as a 'partial mute', and prevent the user from joining study voice channels, "
"or participating in study text channels.\n"
"It will be given automatically after study related offences, "
"such as not enabling video in the video-only channels."
)
@property
def success_response(self):
if self.value:
return "The study ban role is now {}.".format(self.formatted)
@GuildSettings.attach_setting
class studyban_durations(settings.SettingList, settings.ListData, settings.Setting):
category = "Moderation"
attr_name = 'studyban_durations'
_table_interface = studyban_durations
_id_column = 'guildid'
_data_column = 'duration'
_order_column = "rowid"
_default = [
5 * 60,
60 * 60,
6 * 60 * 60,
24 * 60 * 60,
168 * 60 * 60,
720 * 60 * 60
]
_setting = settings.Duration
write_ward = guild_admin
display_name = "studyban_durations"
desc = "Sequence of durations for automatic study bans."
long_desc = (
"This sequence describes how long a member will be automatically study-banned for "
"after committing a study-related offence (such as not enabling their video in video only channels).\n"
"If the sequence is `1d, 7d, 30d`, for example, the member will be study-banned "
"for `1d` on their first offence, `7d` on their second offence, and `30d` on their third. "
"On their fourth offence, they will not be unbanned.\n"
"This does not count pardoned offences."
)
accepts = (
"Comma separated list of durations in days/hours/minutes/seconds, for example `12h, 1d, 7d, 30d`."
)
# Flat cache, no need to expire objects
_cache = {}
@property
def success_response(self):
if self.value:
return "The automatic study ban durations are now {}.".format(self.formatted)
else:
return "Automatic study bans will never be reverted."

View File

@@ -0,0 +1,448 @@
"""
Shared commands for the moderation module.
"""
import asyncio
from collections import defaultdict
import discord
from cmdClient.lib import ResponseTimedOut
from wards import guild_moderator
from .module import module
from .tickets import Ticket, TicketType, TicketState
type_accepts = {
'note': TicketType.NOTE,
'notes': TicketType.NOTE,
'studyban': TicketType.STUDY_BAN,
'studybans': TicketType.STUDY_BAN,
'warn': TicketType.WARNING,
'warns': TicketType.WARNING,
'warning': TicketType.WARNING,
'warnings': TicketType.WARNING,
}
type_formatted = {
TicketType.NOTE: 'NOTE',
TicketType.STUDY_BAN: 'STUDYBAN',
TicketType.WARNING: 'WARNING',
}
type_summary_formatted = {
TicketType.NOTE: 'note',
TicketType.STUDY_BAN: 'studyban',
TicketType.WARNING: 'warning',
}
state_formatted = {
TicketState.OPEN: 'ACTIVE',
TicketState.EXPIRING: 'TEMP',
TicketState.EXPIRED: 'EXPIRED',
TicketState.PARDONED: 'PARDONED'
}
state_summary_formatted = {
TicketState.OPEN: 'Active',
TicketState.EXPIRING: 'Temporary',
TicketState.EXPIRED: 'Expired',
TicketState.REVERTED: 'Manually Reverted',
TicketState.PARDONED: 'Pardoned'
}
@module.cmd(
"tickets",
group="Moderation",
desc="View and filter the server moderation tickets.",
flags=('active', 'type=')
)
@guild_moderator()
async def cmd_tickets(ctx, flags):
"""
Usage``:
{prefix}tickets [@user] [--type <type>] [--active]
Description:
Display and optionally filter the moderation event history in this guild.
Flags::
type: Filter by ticket type. See **Ticket Types** below.
active: Only show in-effect tickets (i.e. hide expired and pardoned ones).
Ticket Types::
note: Moderation notes.
warn: Moderation warnings, both manual and automatic.
studyban: Bans from using study features from abusing the study system.
blacklist: Complete blacklisting from using my commands.
Ticket States::
Active: Active tickets that will not automatically expire.
Temporary: Active tickets that will automatically expire after a set duration.
Expired: Tickets that have automatically expired.
Reverted: Tickets with actions that have been reverted.
Pardoned: Tickets that have been pardoned and no longer apply to the user.
Examples:
{prefix}tickets {ctx.guild.owner.mention} --type warn --active
"""
# Parse filter fields
# First the user
if ctx.args:
userstr = ctx.args.strip('<@!&> ')
if not userstr.isdigit():
return await ctx.error_reply(
"**Usage:** `{prefix}tickets [@user] [--type <type>] [--active]`.\n"
"Please provide the `user` as a mention or id!".format(prefix=ctx.best_prefix)
)
filter_userid = int(userstr)
else:
filter_userid = None
if flags['type']:
typestr = flags['type'].lower()
if typestr not in type_accepts:
return await ctx.error_reply(
"Please see `{prefix}help tickets` for the valid ticket types!".format(prefix=ctx.best_prefix)
)
filter_type = type_accepts[typestr]
else:
filter_type = None
filter_active = flags['active']
# Build the filter arguments
filters = {'guildid': ctx.guild.id}
if filter_userid:
filters['targetid'] = filter_userid
if filter_type:
filters['ticket_type'] = filter_type
if filter_active:
filters['ticket_state'] = [TicketState.OPEN, TicketState.EXPIRING]
# Fetch the tickets with these filters
tickets = Ticket.fetch_tickets(**filters)
if not tickets:
if filters:
return await ctx.embed_reply("There are no tickets with these criteria!")
else:
return await ctx.embed_reply("There are no moderation tickets in this server!")
tickets = sorted(tickets, key=lambda ticket: ticket.data.guild_ticketid, reverse=True)
ticket_map = {ticket.data.guild_ticketid: ticket for ticket in tickets}
# Build the format string based on the filters
components = []
# Ticket id with link to message in mod log
components.append("[#{ticket.data.guild_ticketid}]({ticket.link})")
# Ticket creation date
components.append("<t:{timestamp:.0f}:d>")
# Ticket type, with current state
if filter_type is None:
if not filter_active:
components.append("`{ticket_type}{ticket_state}`")
else:
components.append("`{ticket_type}`")
elif not filter_active:
components.append("`{ticket_real_state}`")
if not filter_userid:
# Ticket user
components.append("<@{ticket.data.targetid}>")
if filter_userid or (filter_active and filter_type):
# Truncated ticket content
components.append("{content}")
format_str = ' | '.join(components)
# Break tickets into blocks
blocks = [tickets[i:i+10] for i in range(0, len(tickets), 10)]
# Build pages of tickets
ticket_pages = []
for block in blocks:
ticket_page = []
type_len = max(len(type_formatted[ticket.type]) for ticket in block)
state_len = max(len(state_formatted[ticket.state]) for ticket in block)
for ticket in block:
# First truncate content if required
content = ticket.data.content
if len(content) > 40:
content = content[:37] + '...'
# Build ticket line
line = format_str.format(
ticket=ticket,
timestamp=ticket.data.created_at.timestamp(),
ticket_type=type_formatted[ticket.type],
type_len=type_len,
ticket_state=" [{}]".format(state_formatted[ticket.state]) if ticket.state != TicketState.OPEN else '',
ticket_real_state=state_formatted[ticket.state],
state_len=state_len,
content=content
)
if ticket.state == TicketState.PARDONED:
line = "~~{}~~".format(line)
# Add to current page
ticket_page.append(line)
# Combine lines and add page to pages
ticket_pages.append('\n'.join(ticket_page))
# Build active ticket type summary
freq = defaultdict(int)
for ticket in tickets:
if ticket.state != TicketState.PARDONED:
freq[ticket.type] += 1
summary_pairs = [
(num, type_summary_formatted[ttype] + ('s' if num > 1 else ''))
for ttype, num in freq.items()
]
summary_pairs.sort(key=lambda pair: pair[0])
# num_len = max(len(str(num)) for num in freq.values())
# type_summary = '\n'.join(
# "**`{:<{}}`** {}".format(pair[0], num_len, pair[1])
# for pair in summary_pairs
# )
# # Build status summary
# freq = defaultdict(int)
# for ticket in tickets:
# freq[ticket.state] += 1
# num_len = max(len(str(num)) for num in freq.values())
# status_summary = '\n'.join(
# "**`{:<{}}`** {}".format(freq[state], num_len, state_str)
# for state, state_str in state_summary_formatted.items()
# if state in freq
# )
summary_strings = [
"**`{}`** {}".format(*pair) for pair in summary_pairs
]
if len(summary_strings) > 2:
summary = ', '.join(summary_strings[:-1]) + ', and ' + summary_strings[-1]
elif len(summary_strings) == 2:
summary = ' and '.join(summary_strings)
else:
summary = ''.join(summary_strings)
if summary:
summary += '.'
# Build embed info
title = "{}{}{}".format(
"Active " if filter_active else '',
"{} tickets ".format(type_formatted[filter_type]) if filter_type else "Tickets ",
(" for {}".format(ctx.guild.get_member(filter_userid) or filter_userid)
if filter_userid else " in {}".format(ctx.guild.name))
)
footer = "Click a ticket id to jump to it, or type the number to show the full ticket."
page_count = len(blocks)
if page_count > 1:
footer += "\nPage {{page_num}}/{}".format(page_count)
# Create embeds
embeds = [
discord.Embed(
title=title,
description="{}\n{}".format(summary, page),
colour=discord.Colour.orange(),
).set_footer(text=footer.format(page_num=i+1))
for i, page in enumerate(ticket_pages)
]
# Run output with cancellation and listener
out_msg = await ctx.pager(embeds, add_cancel=True)
old_task = _displays.pop((ctx.ch.id, ctx.author.id), None)
if old_task:
old_task.cancel()
_displays[(ctx.ch.id, ctx.author.id)] = display_task = asyncio.create_task(_ticket_display(ctx, ticket_map))
ctx.tasks.append(display_task)
await ctx.cancellable(out_msg, add_reaction=False)
_displays = {} # (channelid, userid) -> Task
async def _ticket_display(ctx, ticket_map):
"""
Display tickets when the ticket number is entered.
"""
current_ticket_msg = None
try:
while True:
# Wait for a number
try:
result = await ctx.client.wait_for(
"message",
check=lambda msg: (msg.author == ctx.author
and msg.channel == ctx.ch
and msg.content.isdigit()
and int(msg.content) in ticket_map),
timeout=60
)
except asyncio.TimeoutError:
return
# Delete the response
try:
await result.delete()
except discord.HTTPException:
pass
# Display the ticket
embed = ticket_map[int(result.content)].msg_args['embed']
if current_ticket_msg:
try:
await current_ticket_msg.edit(embed=embed)
except discord.HTTPException:
current_ticket_msg = None
if not current_ticket_msg:
try:
current_ticket_msg = await ctx.reply(embed=embed)
except discord.HTTPException:
return
asyncio.create_task(ctx.offer_delete(current_ticket_msg))
except asyncio.CancelledError:
if current_ticket_msg:
try:
await current_ticket_msg.delete()
except discord.HTTPException:
pass
@module.cmd(
"pardon",
group="Moderation",
desc="Pardon a ticket, or clear a member's moderation history.",
flags=('type=',)
)
@guild_moderator()
async def cmd_pardon(ctx, flags):
"""
Usage``:
{prefix}pardon ticketid, ticketid, ticketid
{prefix}pardon @user [--type <type>]
Description:
Marks the given tickets as no longer applicable.
These tickets will not be considered when calculating automod actions such as automatic study bans.
This may be used to mark warns or other tickets as no longer in-effect.
If the ticket is active when it is pardoned, it will be reverted, and any expiry cancelled.
Use the `{prefix}tickets` command to view the relevant tickets.
Flags::
type: Filter by ticket type. See **Ticket Types** in `{prefix}help tickets`.
Examples:
{prefix}pardon 21
{prefix}pardon {ctx.guild.owner.mention} --type warn
"""
usage = "**Usage**: `{prefix}pardon ticketid` or `{prefix}pardon @user`.".format(prefix=ctx.best_prefix)
if not ctx.args:
return await ctx.error_reply(
usage
)
# Parse provided tickets or filters
targetid = None
ticketids = []
args = {'guildid': ctx.guild.id}
if ',' in ctx.args:
# Assume provided numbers are ticketids.
items = [item.strip() for item in ctx.args.split(',')]
if not all(item.isdigit() for item in items):
return await ctx.error_reply(usage)
ticketids = [int(item) for item in items]
args['guild_ticketid'] = ticketids
else:
# Guess whether the provided numbers were ticketids or not
idstr = ctx.args.strip('<@!&> ')
if not idstr.isdigit():
return await ctx.error_reply(usage)
maybe_id = int(idstr)
if maybe_id > 4194304: # Testing whether it is greater than the minimum snowflake id
# Assume userid
targetid = maybe_id
args['targetid'] = maybe_id
# Add the type filter if provided
if flags['type']:
typestr = flags['type'].lower()
if typestr not in type_accepts:
return await ctx.error_reply(
"Please see `{prefix}help tickets` for the valid ticket types!".format(prefix=ctx.best_prefix)
)
args['ticket_type'] = type_accepts[typestr]
else:
# Assume guild ticketid
ticketids = [maybe_id]
args['guild_ticketid'] = maybe_id
# Fetch the matching tickets
tickets = Ticket.fetch_tickets(**args)
# Check whether we have the right selection of tickets
if targetid and not tickets:
return await ctx.error_reply(
"<@{}> has no matching tickets to pardon!"
)
if ticketids and len(ticketids) != len(tickets):
# Not all of the ticketids were valid
difference = list(set(ticketids).difference(ticket.ticketid for ticket in tickets))
if len(difference) == 1:
return await ctx.error_reply(
"Couldn't find ticket `{}`!".format(difference[0])
)
else:
return await ctx.error_reply(
"Couldn't find any of the following tickets:\n`{}`".format(
'`, `'.join(difference)
)
)
# Check whether there are any tickets left to pardon
to_pardon = [ticket for ticket in tickets if ticket.state != TicketState.PARDONED]
if not to_pardon:
if ticketids and len(tickets) == 1:
ticket = tickets[0]
return await ctx.error_reply(
"[Ticket #{}]({}) is already pardoned!".format(ticket.data.guild_ticketid, ticket.link)
)
else:
return await ctx.error_reply(
"All of these tickets are already pardoned!"
)
# We now know what tickets we want to pardon
# Request the pardon reason
try:
reason = await ctx.input("Please provide a reason for the pardon.")
except ResponseTimedOut:
raise ResponseTimedOut("Prompt timed out, no tickets were pardoned.")
# Pardon the tickets
for ticket in to_pardon:
await ticket.pardon(ctx.author, reason)
# Finally, ack the pardon
if targetid:
await ctx.embed_reply(
"The active {}s for <@{}> have been cleared.".format(
type_summary_formatted[args['ticket_type']] if flags['type'] else 'ticket',
targetid
)
)
elif len(to_pardon) == 1:
ticket = to_pardon[0]
await ctx.embed_reply(
"[Ticket #{}]({}) was pardoned.".format(
ticket.data.guild_ticketid,
ticket.link
)
)
else:
await ctx.embed_reply(
"The following tickets were pardoned.\n{}".format(
", ".join(
"[#{}]({})".format(ticket.data.guild_ticketid, ticket.link)
for ticket in to_pardon
)
)
)

View File

@@ -0,0 +1,19 @@
from data import Table, RowTable
studyban_durations = Table('studyban_durations')
ticket_info = RowTable(
'ticket_info',
('ticketid', 'guild_ticketid',
'guildid', 'targetid', 'ticket_type', 'ticket_state', 'moderator_id', 'auto',
'log_msg_id', 'created_at',
'content', 'context', 'addendum', 'duration',
'file_name', 'file_data',
'expiry',
'pardoned_by', 'pardoned_at', 'pardoned_reason'),
'ticketid',
cache_size=20000
)
tickets = Table('tickets')

View File

@@ -0,0 +1,4 @@
from cmdClient import Module
module = Module("Moderation")

View File

@@ -0,0 +1,486 @@
import asyncio
import logging
import traceback
import datetime
import discord
from meta import client
from data.conditions import THIS_SHARD
from settings import GuildSettings
from utils.lib import FieldEnum, strfdelta, utc_now
from .. import data
from ..module import module
class TicketType(FieldEnum):
"""
The possible ticket types.
"""
NOTE = 'NOTE', 'Note'
WARNING = 'WARNING', 'Warning'
STUDY_BAN = 'STUDY_BAN', 'Study Ban'
MESAGE_CENSOR = 'MESSAGE_CENSOR', 'Message Censor'
INVITE_CENSOR = 'INVITE_CENSOR', 'Invite Censor'
class TicketState(FieldEnum):
"""
The possible ticket states.
"""
OPEN = 'OPEN', "Active"
EXPIRING = 'EXPIRING', "Active"
EXPIRED = 'EXPIRED', "Expired"
PARDONED = 'PARDONED', "Pardoned"
REVERTED = 'REVERTED', "Reverted"
class Ticket:
"""
Abstract base class representing a Ticketed moderation action.
"""
# Type of event the class represents
_ticket_type = None # type: TicketType
_ticket_types = {} # Map: TicketType -> Ticket subclass
_expiry_tasks = {} # Map: ticketid -> expiry Task
def __init__(self, ticketid, *args, **kwargs):
self.ticketid = ticketid
@classmethod
async def create(cls, *args, **kwargs):
"""
Method used to create a new ticket of the current type.
Should add a row to the ticket table, post the ticket, and return the Ticket.
"""
raise NotImplementedError
@property
def data(self):
"""
Ticket row.
This will usually be a row of `ticket_info`.
"""
return data.ticket_info.fetch(self.ticketid)
@property
def guild(self):
return client.get_guild(self.data.guildid)
@property
def target(self):
guild = self.guild
return guild.get_member(self.data.targetid) if guild else None
@property
def msg_args(self):
"""
Ticket message posted in the moderation log.
"""
args = {}
# Build embed
info = self.data
member = self.target
name = str(member) if member else str(info.targetid)
if info.auto:
title_fmt = "Ticket #{} | {} | {}[Auto] | {}"
else:
title_fmt = "Ticket #{} | {} | {} | {}"
title = title_fmt.format(
info.guild_ticketid,
TicketState(info.ticket_state).desc,
TicketType(info.ticket_type).desc,
name
)
embed = discord.Embed(
title=title,
description=info.content,
timestamp=info.created_at
)
embed.add_field(
name="Target",
value="<@{}>".format(info.targetid)
)
if not info.auto:
embed.add_field(
name="Moderator",
value="<@{}>".format(info.moderator_id)
)
# if info.duration:
# value = "`{}` {}".format(
# strfdelta(datetime.timedelta(seconds=info.duration)),
# "(Expiry <t:{:.0f}>)".format(info.expiry.timestamp()) if info.expiry else ""
# )
# embed.add_field(
# name="Duration",
# value=value
# )
if info.expiry:
if info.ticket_state == TicketState.EXPIRING:
embed.add_field(
name="Expires at",
value="<t:{:.0f}>\n(Duration: `{}`)".format(
info.expiry.timestamp(),
strfdelta(datetime.timedelta(seconds=info.duration))
)
)
elif info.ticket_state == TicketState.EXPIRED:
embed.add_field(
name="Expired",
value="<t:{:.0f}>".format(
info.expiry.timestamp(),
)
)
else:
embed.add_field(
name="Expiry",
value="<t:{:.0f}>".format(
info.expiry.timestamp()
)
)
if info.context:
embed.add_field(
name="Context",
value=info.context,
inline=False
)
if info.addendum:
embed.add_field(
name="Notes",
value=info.addendum,
inline=False
)
if self.state == TicketState.PARDONED:
embed.add_field(
name="Pardoned",
value=(
"Pardoned by <@{}> at <t:{:.0f}>.\n{}"
).format(
info.pardoned_by,
info.pardoned_at.timestamp(),
info.pardoned_reason or ""
),
inline=False
)
embed.set_footer(text="ID: {}".format(info.targetid))
args['embed'] = embed
# Add file
if info.file_name:
args['file'] = discord.File(info.file_data, info.file_name)
return args
@property
def link(self):
"""
The link to the ticket in the moderation log.
"""
info = self.data
modlog = GuildSettings(info.guildid).mod_log.data
return 'https://discord.com/channels/{}/{}/{}'.format(
info.guildid,
modlog,
info.log_msg_id
)
@property
def state(self):
return TicketState(self.data.ticket_state)
@property
def type(self):
return TicketType(self.data.ticket_type)
async def update(self, **kwargs):
"""
Update ticket fields.
"""
fields = (
'targetid', 'moderator_id', 'auto', 'log_msg_id',
'content', 'expiry', 'ticket_state',
'context', 'addendum', 'duration', 'file_name', 'file_data',
'pardoned_by', 'pardoned_at', 'pardoned_reason',
)
params = {field: kwargs[field] for field in fields if field in kwargs}
if params:
data.ticket_info.update_where(params, ticketid=self.ticketid)
await self.update_expiry()
await self.post()
async def post(self):
"""
Post or update the ticket in the moderation log.
Also updates the saved message id.
"""
info = self.data
modlog = GuildSettings(info.guildid).mod_log.value
if not modlog:
return
resend = True
try:
if info.log_msg_id:
# Try to fetch the message
message = await modlog.fetch_message(info.log_msg_id)
if message:
if message.author.id == client.user.id:
# TODO: Handle file edit
await message.edit(embed=self.msg_args['embed'])
resend = False
else:
try:
await message.delete()
except discord.HTTPException:
pass
if resend:
message = await modlog.send(**self.msg_args)
self.data.log_msg_id = message.id
except discord.HTTPException:
client.log(
"Cannot post ticket (tid: {}) due to discord exception or issue.".format(self.ticketid)
)
except Exception:
# This should never happen in normal operation
client.log(
"Error while posting ticket (tid:{})! "
"Exception traceback follows.\n{}".format(
self.ticketid,
traceback.format_exc()
),
context="TICKETS",
level=logging.ERROR
)
@classmethod
def load_expiring(cls):
"""
Load and schedule all expiring tickets.
"""
# TODO: Consider changing this to a flat timestamp system, to avoid storing lots of coroutines.
# TODO: Consider only scheduling the expiries in the next day, and updating this once per day.
# TODO: Only fetch tickets from guilds we are in.
# Cancel existing expiry tasks
for task in cls._expiry_tasks.values():
if not task.done():
task.cancel()
# Get all expiring tickets
expiring_rows = data.tickets.select_where(
ticket_state=TicketState.EXPIRING,
guildid=THIS_SHARD
)
# Create new expiry tasks
now = utc_now()
cls._expiry_tasks = {
row['ticketid']: asyncio.create_task(
cls._schedule_expiry_for(
row['ticketid'],
(row['expiry'] - now).total_seconds()
)
) for row in expiring_rows
}
# Log
client.log(
"Loaded {} expiring tickets.".format(len(cls._expiry_tasks)),
context="TICKET_LOADER",
)
@classmethod
async def _schedule_expiry_for(cls, ticketid, delay):
"""
Schedule expiry for a given ticketid
"""
try:
await asyncio.sleep(delay)
ticket = Ticket.fetch(ticketid)
if ticket:
await asyncio.shield(ticket._expire())
except asyncio.CancelledError:
return
def update_expiry(self):
# Cancel any existing expiry task
task = self._expiry_tasks.pop(self.ticketid, None)
if task and not task.done():
task.cancel()
# Schedule a new expiry task, if applicable
if self.data.ticket_state == TicketState.EXPIRING:
self._expiry_tasks[self.ticketid] = asyncio.create_task(
self._schedule_expiry_for(
self.ticketid,
(self.data.expiry - utc_now()).total_seconds()
)
)
async def cancel_expiry(self):
"""
Cancel ticket expiry.
In particular, may be used if another ticket overrides `self`.
Sets the ticket state to `OPEN`, so that it no longer expires.
"""
if self.state == TicketState.EXPIRING:
# Update the ticket state
self.data.ticket_state = TicketState.OPEN
# Remove from expiry tsks
self.update_expiry()
# Repost
await self.post()
async def _revert(self, reason=None):
"""
Method used to revert the ticket action, e.g. unban or remove mute role.
Generally called by `pardon` and `_expire`.
May be overriden by the Ticket type, if they implement any revert logic.
Is a no-op by default.
"""
return
async def _expire(self):
"""
Method to automatically expire a ticket.
May be overriden by the Ticket type for more complex expiry logic.
Must set `data.ticket_state` to `EXPIRED` if applicable.
"""
if self.state == TicketState.EXPIRING:
client.log(
"Automatically expiring ticket (tid:{}).".format(self.ticketid),
context="TICKETS"
)
try:
await self._revert(reason="Automatic Expiry")
except Exception:
# This should never happen in normal operation
client.log(
"Error while expiring ticket (tid:{})! "
"Exception traceback follows.\n{}".format(
self.ticketid,
traceback.format_exc()
),
context="TICKETS",
level=logging.ERROR
)
# Update state
self.data.ticket_state = TicketState.EXPIRED
# Update log message
await self.post()
# Post a note to the modlog
modlog = GuildSettings(self.data.guildid).mod_log.value
if modlog:
try:
await modlog.send(
embed=discord.Embed(
colour=discord.Colour.orange(),
description="[Ticket #{}]({}) expired!".format(self.data.guild_ticketid, self.link)
)
)
except discord.HTTPException:
pass
async def pardon(self, moderator, reason, timestamp=None):
"""
Pardon process for the ticket.
May be overidden by the Ticket type for more complex pardon logic.
Must set `data.ticket_state` to `PARDONED` if applicable.
"""
if self.state != TicketState.PARDONED:
if self.state in (TicketState.OPEN, TicketState.EXPIRING):
try:
await self._revert(reason="Pardoned by {}".format(moderator.id))
except Exception:
# This should never happen in normal operation
client.log(
"Error while pardoning ticket (tid:{})! "
"Exception traceback follows.\n{}".format(
self.ticketid,
traceback.format_exc()
),
context="TICKETS",
level=logging.ERROR
)
# Update state
with self.data.batch_update():
self.data.ticket_state = TicketState.PARDONED
self.data.pardoned_at = utc_now()
self.data.pardoned_by = moderator.id
self.data.pardoned_reason = reason
# Update (i.e. remove) expiry
self.update_expiry()
# Update log message
await self.post()
@classmethod
def fetch_tickets(cls, *ticketids, **kwargs):
"""
Fetch tickets matching the given criteria (passed transparently to `select_where`).
Positional arguments are treated as `ticketids`, which are not supported in keyword arguments.
"""
if ticketids:
kwargs['ticketid'] = ticketids
# Set the ticket type to the class type if not specified
if cls._ticket_type and 'ticket_type' not in kwargs:
kwargs['ticket_type'] = cls._ticket_type
# This is actually mainly for caching, since we don't pass the data to the initialiser
rows = data.ticket_info.fetch_rows_where(
**kwargs
)
return [
cls._ticket_types[TicketType(row.ticket_type)](row.ticketid)
for row in rows
]
@classmethod
def fetch(cls, ticketid):
"""
Return the Ticket with the given id, if found, or `None` otherwise.
"""
tickets = cls.fetch_tickets(ticketid)
return tickets[0] if tickets else None
@classmethod
def register_ticket_type(cls, ticket_cls):
"""
Decorator to register a new Ticket subclass as a ticket type.
"""
cls._ticket_types[ticket_cls._ticket_type] = ticket_cls
return ticket_cls
@module.launch_task
async def load_expiring_tickets(client):
Ticket.load_expiring()

View File

@@ -0,0 +1,4 @@
from .Ticket import Ticket, TicketType, TicketState
from .studybans import StudyBanTicket
from .notes import NoteTicket
from .warns import WarnTicket

View File

@@ -0,0 +1,112 @@
"""
Note ticket implementation.
Guild moderators can add a note about a user, visible in their moderation history.
Notes appear in the moderation log and the user's ticket history, like any other ticket.
This module implements the Note TicketType and the `note` moderation command.
"""
from cmdClient.lib import ResponseTimedOut
from wards import guild_moderator
from ..module import module
from ..data import tickets
from .Ticket import Ticket, TicketType, TicketState
@Ticket.register_ticket_type
class NoteTicket(Ticket):
_ticket_type = TicketType.NOTE
@classmethod
async def create(cls, guildid, targetid, moderatorid, content, **kwargs):
"""
Create a new Note on a target.
`kwargs` are passed transparently to the table insert method.
"""
ticket_row = tickets.insert(
guildid=guildid,
targetid=targetid,
ticket_type=cls._ticket_type,
ticket_state=TicketState.OPEN,
moderator_id=moderatorid,
auto=False,
content=content,
**kwargs
)
# Create the note ticket
ticket = cls(ticket_row['ticketid'])
# Post the ticket and return
await ticket.post()
return ticket
@module.cmd(
"note",
group="Moderation",
desc="Add a Note to a member's record."
)
@guild_moderator()
async def cmd_note(ctx):
"""
Usage``:
{prefix}note @target
{prefix}note @target <content>
Description:
Add a note to the target's moderation record.
The note will appear in the moderation log and in the `tickets` command.
The `target` must be specificed by mention or user id.
If the `content` is not given, it will be prompted for.
Example:
{prefix}note {ctx.author.mention} Seen reading the `note` documentation.
"""
if not ctx.args:
return await ctx.error_reply(
"**Usage:** `{}note @target <content>`.".format(ctx.best_prefix)
)
# Extract the target. We don't require them to be in the server
splits = ctx.args.split(maxsplit=1)
target_str = splits[0].strip('<@!&> ')
if not target_str.isdigit():
return await ctx.error_reply(
"**Usage:** `{}note @target <content>`.\n"
"`target` must be provided by mention or userid.".format(ctx.best_prefix)
)
targetid = int(target_str)
# Extract or prompt for the content
if len(splits) != 2:
try:
content = await ctx.input("What note would you like to add?", timeout=300)
except ResponseTimedOut:
raise ResponseTimedOut("Prompt timed out, no note was created.")
else:
content = splits[1].strip()
# Create the note ticket
ticket = await NoteTicket.create(
ctx.guild.id,
targetid,
ctx.author.id,
content
)
if ticket.data.log_msg_id:
await ctx.embed_reply(
"Note on <@{}> created as [Ticket #{}]({}).".format(
targetid,
ticket.data.guild_ticketid,
ticket.link
)
)
else:
await ctx.embed_reply(
"Note on <@{}> created as Ticket #{}.".format(targetid, ticket.data.guild_ticketid)
)

View File

@@ -0,0 +1,126 @@
import datetime
import discord
from meta import client
from utils.lib import utc_now
from settings import GuildSettings
from data import NOT
from .. import data
from .Ticket import Ticket, TicketType, TicketState
@Ticket.register_ticket_type
class StudyBanTicket(Ticket):
_ticket_type = TicketType.STUDY_BAN
@classmethod
async def create(cls, guildid, targetid, moderatorid, reason, expiry=None, **kwargs):
"""
Create a new study ban ticket.
"""
# First create the ticket itself
ticket_row = data.tickets.insert(
guildid=guildid,
targetid=targetid,
ticket_type=cls._ticket_type,
ticket_state=TicketState.EXPIRING if expiry else TicketState.OPEN,
moderator_id=moderatorid,
auto=(moderatorid == client.user.id),
content=reason,
expiry=expiry,
**kwargs
)
# Create the Ticket
ticket = cls(ticket_row['ticketid'])
# Schedule ticket expiry, if applicable
if expiry:
ticket.update_expiry()
# Cancel any existing studyban expiry for this member
tickets = cls.fetch_tickets(
guildid=guildid,
ticketid=NOT(ticket_row['ticketid']),
targetid=targetid,
ticket_state=TicketState.EXPIRING
)
for ticket in tickets:
await ticket.cancel_expiry()
# Post the ticket
await ticket.post()
# Return the ticket
return ticket
async def _revert(self, reason=None):
"""
Revert the studyban by removing the role.
"""
guild_settings = GuildSettings(self.data.guildid)
role = guild_settings.studyban_role.value
target = self.target
if target and role:
try:
await target.remove_roles(
role,
reason="Reverting StudyBan: {}".format(reason)
)
except discord.HTTPException:
# TODO: Error log?
...
@classmethod
async def autoban(cls, guild, target, reason, **kwargs):
"""
Convenience method to automatically studyban a member, for the configured duration.
If the role is set, this will create and return a `StudyBanTicket` regardless of whether the
studyban was successful.
If the role is not set, or the ticket cannot be created, this will return `None`.
"""
# Get the studyban role, fail if there isn't one set, or the role doesn't exist
guild_settings = GuildSettings(guild.id)
role = guild_settings.studyban_role.value
if not role:
return None
# Attempt to add the role, record failure
try:
await target.add_roles(role, reason="Applying StudyBan: {}".format(reason[:400]))
except discord.HTTPException:
role_failed = True
else:
role_failed = False
# Calculate the applicable automatic duration and expiry
# First count the existing non-pardoned studybans for this target
studyban_count = data.tickets.select_one_where(
guildid=guild.id,
targetid=target.id,
ticket_type=cls._ticket_type,
ticket_state=NOT(TicketState.PARDONED),
select_columns=('COUNT(*)',)
)[0]
studyban_count = int(studyban_count)
# Then read the guild setting to find the applicable duration
studyban_durations = guild_settings.studyban_durations.value
if studyban_count < len(studyban_durations):
duration = studyban_durations[studyban_count]
expiry = utc_now() + datetime.timedelta(seconds=duration)
else:
duration = None
expiry = None
# Create the ticket and return
if role_failed:
kwargs['addendum'] = '\n'.join((
kwargs.get('addendum', ''),
"Could not add the studyban role! Please add the role manually and check my permissions."
))
return await cls.create(
guild.id, target.id, client.user.id, reason, duration=duration, expiry=expiry, **kwargs
)

View File

@@ -0,0 +1,153 @@
"""
Warn ticket implementation.
Guild moderators can officially warn a user via command.
This DMs the users with the warning.
"""
import datetime
import discord
from cmdClient.lib import ResponseTimedOut
from wards import guild_moderator
from ..module import module
from ..data import tickets
from .Ticket import Ticket, TicketType, TicketState
@Ticket.register_ticket_type
class WarnTicket(Ticket):
_ticket_type = TicketType.WARNING
@classmethod
async def create(cls, guildid, targetid, moderatorid, content, **kwargs):
"""
Create a new Warning for the target.
`kwargs` are passed transparently to the table insert method.
"""
ticket_row = tickets.insert(
guildid=guildid,
targetid=targetid,
ticket_type=cls._ticket_type,
ticket_state=TicketState.OPEN,
moderator_id=moderatorid,
content=content,
**kwargs
)
# Create the note ticket
ticket = cls(ticket_row['ticketid'])
# Post the ticket and return
await ticket.post()
return ticket
async def _revert(*args, **kwargs):
# Warnings don't have a revert process
pass
@module.cmd(
"warn",
group="Moderation",
desc="Officially warn a user for a misbehaviour."
)
@guild_moderator()
async def cmd_warn(ctx):
"""
Usage``:
{prefix}warn @target
{prefix}warn @target <reason>
Description:
The `target` must be specificed by mention or user id.
If the `reason` is not given, it will be prompted for.
Example:
{prefix}warn {ctx.author.mention} Don't actually read the documentation!
"""
if not ctx.args:
return await ctx.error_reply(
"**Usage:** `{}warn @target <reason>`.".format(ctx.best_prefix)
)
# Extract the target. We do require them to be in the server
splits = ctx.args.split(maxsplit=1)
target_str = splits[0].strip('<@!&> ')
if not target_str.isdigit():
return await ctx.error_reply(
"**Usage:** `{}warn @target <reason>`.\n"
"`target` must be provided by mention or userid.".format(ctx.best_prefix)
)
targetid = int(target_str)
target = ctx.guild.get_member(targetid)
if not target:
return await ctx.error_reply("Cannot warn a user who is not in the server!")
# Extract or prompt for the content
if len(splits) != 2:
try:
content = await ctx.input("Please give a reason for this warning!", timeout=300)
except ResponseTimedOut:
raise ResponseTimedOut("Prompt timed out, the member was not warned.")
else:
content = splits[1].strip()
# Create the warn ticket
ticket = await WarnTicket.create(
ctx.guild.id,
targetid,
ctx.author.id,
content
)
# Attempt to message the member
embed = discord.Embed(
title="You have received a warning!",
description=(
content
),
colour=discord.Colour.red(),
timestamp=datetime.datetime.utcnow()
)
embed.add_field(
name="Info",
value=(
"*Warnings appear in your moderation history. "
"Failure to comply, or repeated warnings, "
"may result in muting, studybanning, or server banning.*"
)
)
embed.set_footer(
icon_url=ctx.guild.icon_url,
text=ctx.guild.name
)
dm_msg = None
try:
dm_msg = await target.send(embed=embed)
except discord.HTTPException:
pass
# Get previous warnings
count = tickets.select_one_where(
guildid=ctx.guild.id,
targetid=targetid,
ticket_type=TicketType.WARNING,
ticket_state=[TicketState.OPEN, TicketState.EXPIRING],
select_columns=('COUNT(*)',)
)[0]
if count == 1:
prev_str = "This is their first warning."
else:
prev_str = "They now have `{}` warnings.".format(count)
await ctx.embed_reply(
"[Ticket #{}]({}): {} has been warned. {}\n{}".format(
ticket.data.guild_ticketid,
ticket.link,
target.mention,
prev_str,
"*Could not DM the user their warning!*" if not dm_msg else ''
)
)

View File

@@ -0,0 +1,4 @@
from . import data
from . import admin
from . import watchdog

View File

@@ -0,0 +1,128 @@
from collections import defaultdict
from settings import GuildSettings, GuildSetting
from wards import guild_admin
import settings
from .data import video_channels
@GuildSettings.attach_setting
class video_channels(settings.ChannelList, settings.ListData, settings.Setting):
category = "Video Channels"
attr_name = 'video_channels'
_table_interface = video_channels
_id_column = 'guildid'
_data_column = 'channelid'
_setting = settings.VoiceChannel
write_ward = guild_admin
display_name = "video_channels"
desc = "Channels where members are required to enable their video."
_force_unique = True
long_desc = (
"Members must keep their video enabled in these channels.\n"
"If they do not keep their video enabled, they will be asked to enable it in their DMS after `15` seconds, "
"and then kicked from the channel with another warning after the `video_grace_period` duration has passed.\n"
"After the first offence, if the `video_studyban` is enabled and the `studyban_role` is set, "
"they will also be automatically studybanned."
)
# Flat cache, no need to expire objects
_cache = {}
@property
def success_response(self):
if self.value:
return "Members must enable their video in the following channels:\n{}".format(self.formatted)
else:
return "There are no video-required channels set up."
@classmethod
async def launch_task(cls, client):
"""
Launch initialisation step for the `video_channels` setting.
Pre-fill cache for the guilds with currently active voice channels.
"""
active_guildids = [
guild.id
for guild in client.guilds
if any(channel.members for channel in guild.voice_channels)
]
if active_guildids:
cache = {guildid: [] for guildid in active_guildids}
rows = cls._table_interface.select_where(
guildid=active_guildids
)
for row in rows:
cache[row['guildid']].append(row['channelid'])
cls._cache.update(cache)
@GuildSettings.attach_setting
class video_studyban(settings.Boolean, GuildSetting):
category = "Video Channels"
attr_name = 'video_studyban'
_data_column = 'video_studyban'
display_name = "video_studyban"
desc = "Whether to studyban members if they don't enable their video."
long_desc = (
"If enabled, members who do not enable their video in the configured `video_channels` will be "
"study-banned after a single warning.\n"
"When disabled, members will only be warned and removed from the channel."
)
_default = True
_outputs = {True: "Enabled", False: "Disabled"}
@property
def success_response(self):
if self.value:
return "Members will now be study-banned if they don't enable their video in the configured video channels."
else:
return "Members will not be study-banned if they don't enable their video in video channels."
@GuildSettings.attach_setting
class video_grace_period(settings.Duration, GuildSetting):
category = "Video Channels"
attr_name = 'video_grace_period'
_data_column = 'video_grace_period'
display_name = "video_grace_period"
desc = "How long to wait before kicking/studybanning members who don't enable their video."
long_desc = (
"The period after a member has been asked to enable their video in a video-only channel "
"before they will be kicked from the channel, and warned or studybanned (if enabled)."
)
_default = 90
_default_multiplier = 1
@classmethod
def _format_data(cls, id: int, data, **kwargs):
"""
Return the string version of the data.
"""
if data is None:
return None
else:
return "`{} seconds`".format(data)
@property
def success_response(self):
return (
"Members who do not enable their video will "
"be disconnected after {}.".format(self.formatted)
)

Some files were not shown because too many files have changed in this diff Show More