rewrite: Restructure to include GUI.
This commit is contained in:
5
src/analytics/__init__.py
Normal file
5
src/analytics/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .cog import Analytics
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(Analytics(bot))
|
||||
177
src/analytics/cog.py
Normal file
177
src/analytics/cog.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import logging
|
||||
|
||||
import discord
|
||||
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
|
||||
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
|
||||
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError
|
||||
|
||||
from meta import LionCog, LionBot, LionContext
|
||||
from meta.app import shard_talk, appname
|
||||
from meta.errors import HandledException, SafeCancellation
|
||||
from meta.logger import log_wrap
|
||||
from utils.lib import utc_now
|
||||
|
||||
from .data import AnalyticsData
|
||||
from .events import (
|
||||
CommandStatus, CommandEvent, command_event_handler,
|
||||
GuildAction, GuildEvent, guild_event_handler,
|
||||
VoiceAction, VoiceEvent, voice_event_handler
|
||||
)
|
||||
from .snapshot import shard_snapshot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: Client side might be better handled as a single connection fed by a queue?
|
||||
# Maybe consider this again after the interactive REPL idea
|
||||
# Or if it seems like this is giving an absurd amount of traffic
|
||||
|
||||
|
||||
class Analytics(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = bot.db.load_registry(AnalyticsData())
|
||||
self.an_app = bot.config.analytics['appname']
|
||||
|
||||
self.talk_command_event = command_event_handler.bind(shard_talk).route
|
||||
self.talk_guild_event = guild_event_handler.bind(shard_talk).route
|
||||
self.talk_voice_event = voice_event_handler.bind(shard_talk).route
|
||||
|
||||
self.talk_shard_snapshot = shard_talk.register_route()(shard_snapshot)
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
@LionCog.listener()
|
||||
@log_wrap(action='AnEvent')
|
||||
async def on_voice_state_update(self, member, before, after):
|
||||
if not before.channel and after.channel:
|
||||
# Member joined channel
|
||||
action = VoiceAction.JOINED
|
||||
elif before.channel and not after.channel:
|
||||
# Member left channel
|
||||
action = VoiceAction.LEFT
|
||||
else:
|
||||
# Member change state, we don't need to deal with that
|
||||
return
|
||||
|
||||
event = VoiceEvent(
|
||||
appname=appname,
|
||||
userid=member.id,
|
||||
guildid=member.guild.id,
|
||||
action=action,
|
||||
created_at=utc_now()
|
||||
)
|
||||
if self.an_app not in shard_talk.peers:
|
||||
logger.warning(f"Analytics peer not found, discarding event: {event}")
|
||||
else:
|
||||
await self.talk_voice_event(event).send(self.an_app, wait_for_reply=False)
|
||||
|
||||
@LionCog.listener()
|
||||
@log_wrap(action='AnEvent')
|
||||
async def on_guild_join(self, guild):
|
||||
"""
|
||||
Send guild join event.
|
||||
"""
|
||||
event = GuildEvent(
|
||||
appname=appname,
|
||||
guildid=guild.id,
|
||||
action=GuildAction.JOINED,
|
||||
created_at=utc_now()
|
||||
)
|
||||
if self.an_app not in shard_talk.peers:
|
||||
logger.warning(f"Analytics peer not found, discarding event: {event}")
|
||||
else:
|
||||
await self.talk_guild_event(event).send(self.an_app, wait_for_reply=False)
|
||||
|
||||
@LionCog.listener()
|
||||
@log_wrap(action='AnEvent')
|
||||
async def on_guild_remove(self, guild):
|
||||
"""
|
||||
Send guild leave event
|
||||
"""
|
||||
event = GuildEvent(
|
||||
appname=appname,
|
||||
guildid=guild.id,
|
||||
action=GuildAction.LEFT,
|
||||
created_at=utc_now()
|
||||
)
|
||||
if self.an_app not in shard_talk.peers:
|
||||
logger.warning(f"Analytics peer not found, discarding event: {event}")
|
||||
else:
|
||||
await self.talk_guild_event(event).send(self.an_app, wait_for_reply=False)
|
||||
|
||||
@LionCog.listener()
|
||||
@log_wrap(action='AnEvent')
|
||||
async def on_command_completion(self, ctx: LionContext):
|
||||
"""
|
||||
Send command completed successfully.
|
||||
"""
|
||||
duration = utc_now() - ctx.message.created_at
|
||||
event = CommandEvent(
|
||||
appname=appname,
|
||||
cmdname=ctx.command.name if ctx.command else 'Unknown',
|
||||
cogname=ctx.cog.qualified_name if ctx.cog else None,
|
||||
userid=ctx.author.id,
|
||||
created_at=utc_now(),
|
||||
status=CommandStatus.COMPLETED,
|
||||
execution_time=duration.total_seconds(),
|
||||
guildid=ctx.guild.id if ctx.guild else None,
|
||||
ctxid=ctx.message.id
|
||||
)
|
||||
if self.an_app not in shard_talk.peers:
|
||||
logger.warning(f"Analytics peer not found, discarding event: {event}")
|
||||
else:
|
||||
await self.talk_command_event(event).send(self.an_app, wait_for_reply=False)
|
||||
|
||||
@LionCog.listener()
|
||||
@log_wrap(action='AnEvent')
|
||||
async def on_command_error(self, ctx: LionContext, error):
|
||||
"""
|
||||
Send command failed.
|
||||
"""
|
||||
duration = utc_now() - ctx.message.created_at
|
||||
status = CommandStatus.FAILED
|
||||
err_type = None
|
||||
try:
|
||||
err_type = repr(error)
|
||||
raise error
|
||||
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
|
||||
original = error.original
|
||||
try:
|
||||
err_type = repr(original)
|
||||
if isinstance(original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
|
||||
raise original.original
|
||||
else:
|
||||
raise original
|
||||
except HandledException:
|
||||
status = CommandStatus.CANCELLED
|
||||
except SafeCancellation:
|
||||
status = CommandStatus.CANCELLED
|
||||
except discord.Forbidden:
|
||||
status = CommandStatus.CANCELLED
|
||||
except discord.HTTPException:
|
||||
status = CommandStatus.CANCELLED
|
||||
except Exception:
|
||||
status = CommandStatus.FAILED
|
||||
except CheckFailure:
|
||||
status = CommandStatus.CANCELLED
|
||||
except Exception:
|
||||
status = CommandStatus.FAILED
|
||||
|
||||
event = CommandEvent(
|
||||
appname=appname,
|
||||
cmdname=ctx.command.name if ctx.command else 'Unknown',
|
||||
cogname=ctx.cog.qualified_name if ctx.cog else None,
|
||||
userid=ctx.author.id,
|
||||
created_at=utc_now(),
|
||||
status=status,
|
||||
error=err_type,
|
||||
execution_time=duration.total_seconds(),
|
||||
guildid=ctx.guild.id if ctx.guild else None,
|
||||
ctxid=ctx.message.id
|
||||
)
|
||||
if self.an_app not in shard_talk.peers:
|
||||
logger.warning(f"Analytics peer not found, discarding event: {event}")
|
||||
else:
|
||||
await self.talk_command_event(event).send(self.an_app, wait_for_reply=False)
|
||||
189
src/analytics/data.py
Normal file
189
src/analytics/data.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from enum import Enum
|
||||
|
||||
from data.registry import Registry
|
||||
from data.adapted import RegisterEnum
|
||||
from data.models import RowModel
|
||||
from data.columns import Integer, String, Timestamp, Column
|
||||
|
||||
|
||||
class CommandStatus(Enum):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TYPE analytics.CommandStatus AS ENUM(
|
||||
'COMPLETED',
|
||||
'CANCELLED'
|
||||
'FAILED'
|
||||
);
|
||||
"""
|
||||
COMPLETED = ('COMPLETED',)
|
||||
CANCELLED = ('CANCELLED',)
|
||||
FAILED = ('FAILED',)
|
||||
|
||||
|
||||
class GuildAction(Enum):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TYPE analytics.GuildAction AS ENUM(
|
||||
'JOINED',
|
||||
'LEFT'
|
||||
);
|
||||
"""
|
||||
JOINED = ('JOINED',)
|
||||
LEFT = ('LEFT',)
|
||||
|
||||
|
||||
class VoiceAction(Enum):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TYPE analytics.VoiceAction AS ENUM(
|
||||
'JOINED',
|
||||
'LEFT'
|
||||
);
|
||||
"""
|
||||
JOINED = ('JOINED',)
|
||||
LEFT = ('LEFT',)
|
||||
|
||||
|
||||
class AnalyticsData(Registry, name='analytics'):
|
||||
CommandStatus = RegisterEnum(CommandStatus, name="analytics.CommandStatus")
|
||||
GuildAction = RegisterEnum(GuildAction, name="analytics.GuildAction")
|
||||
VoiceAction = RegisterEnum(VoiceAction, name="analytics.VoiceAction")
|
||||
|
||||
class Snapshots(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE analytics.snapshots(
|
||||
snapshotid SERIAL PRIMARY KEY,
|
||||
appname TEXT NOT NULL REFERENCES bot_config (appname),
|
||||
guild_count INTEGER NOT NULL,
|
||||
member_count INTEGER NOT NULL,
|
||||
user_count INTEGER NOT NULL,
|
||||
in_voice INTEGER NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT (now() at time zone 'utc')
|
||||
);
|
||||
"""
|
||||
_schema_ = 'analytics'
|
||||
_tablename_ = 'snapshots'
|
||||
|
||||
snapshotid = Integer(primary=True)
|
||||
appname = String()
|
||||
guild_count = Integer()
|
||||
member_count = Integer()
|
||||
user_count = Integer()
|
||||
in_voice = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
class Events(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE analytics.events(
|
||||
eventid SERIAL PRIMARY KEY,
|
||||
appname TEXT NOT NULL REFERENCES bot_config (appname),
|
||||
ctxid BIGINT,
|
||||
guildid BIGINT,
|
||||
_created_at TIMESTAMPTZ NOT NULL DEFAULT (now() at time zone 'utc')
|
||||
);
|
||||
"""
|
||||
_schema_ = 'analytics'
|
||||
_tablename_ = 'events'
|
||||
|
||||
eventid = Integer(primary=True)
|
||||
appname = String()
|
||||
ctxid = Integer()
|
||||
guildid = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
class Commands(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE analytics.commands(
|
||||
cmdname TEXT NOT NULL,
|
||||
cogname TEXT,
|
||||
userid BIGINT NOT NULL,
|
||||
status analytics.CommandStatus NOT NULL,
|
||||
execution_time REAL NOT NULL
|
||||
) INHERITS (analytics.events);
|
||||
"""
|
||||
_schema_ = 'analytics'
|
||||
_tablename_ = 'commands'
|
||||
|
||||
eventid = Integer(primary=True)
|
||||
appname = String()
|
||||
ctxid = Integer()
|
||||
guildid = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
cmdname = String()
|
||||
cogname = String()
|
||||
userid = Integer()
|
||||
status: Column[CommandStatus] = Column()
|
||||
error = String()
|
||||
execution_time: Column[float] = Column()
|
||||
|
||||
class Guilds(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE analytics.guilds(
|
||||
guildid BIGINT NOT NULL,
|
||||
action analytics.GuildAction NOT NULL
|
||||
) INHERITS (analytics.events);
|
||||
"""
|
||||
_schema_ = 'analytics'
|
||||
_tablename_ = 'guilds'
|
||||
|
||||
eventid = Integer(primary=True)
|
||||
appname = String()
|
||||
ctxid = Integer()
|
||||
guildid = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
action: Column[GuildAction] = Column()
|
||||
|
||||
class VoiceSession(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE analytics.voice_sessions(
|
||||
userid BIGINT NOT NULL,
|
||||
action analytics.VoiceAction NOT NULL
|
||||
) INHERITS (analytics.events);
|
||||
"""
|
||||
_schema_ = 'analytics'
|
||||
_tablename_ = 'voice_sessions'
|
||||
|
||||
eventid = Integer(primary=True)
|
||||
appname = String()
|
||||
ctxid = Integer()
|
||||
guildid = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
userid = Integer()
|
||||
action: Column[GuildAction] = Column()
|
||||
|
||||
class GuiRender(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE analytics.gui_renders(
|
||||
cardname TEXT NOT NULL,
|
||||
duration INTEGER NOT NULL
|
||||
) INHERITS (analytics.events);
|
||||
"""
|
||||
_schema_ = 'analytics'
|
||||
_tablename_ = 'gui_renders'
|
||||
|
||||
eventid = Integer(primary=True)
|
||||
appname = String()
|
||||
ctxid = Integer()
|
||||
guildid = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
cardname = String()
|
||||
duration = Integer()
|
||||
180
src/analytics/events.py
Normal file
180
src/analytics/events.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import NamedTuple, Optional, Generic, Type, TypeVar
|
||||
|
||||
from meta.ipc import AppRoute, AppClient
|
||||
from meta.logger import logging_context, log_wrap
|
||||
|
||||
from data import RowModel
|
||||
from .data import AnalyticsData, CommandStatus, VoiceAction, GuildAction
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
"""
|
||||
TODO
|
||||
Snapshot type? Incremental or manual?
|
||||
Request snapshot route will require all shards to be online
|
||||
Update batch size before release, or put it in the config
|
||||
"""
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class EventHandler(Generic[T]):
|
||||
def __init__(self, route_name: str, model: Type[RowModel], struct: Type[T], batchsize: int = 20):
|
||||
self.model = model
|
||||
self.struct = struct
|
||||
|
||||
self.batch_size = batchsize
|
||||
|
||||
self.route_name = route_name
|
||||
self._route: Optional[AppRoute] = None
|
||||
self._client: Optional[AppClient] = None
|
||||
|
||||
self.queue: asyncio.Queue[T] = asyncio.Queue()
|
||||
self.batch: list[T] = []
|
||||
self._consumer_task: Optional[asyncio.Task] = None
|
||||
|
||||
@property
|
||||
def route(self):
|
||||
if self._route is None:
|
||||
self._route = AppRoute(self.handle_event, name=self.route_name)
|
||||
return self._route
|
||||
|
||||
async def handle_event(self, data):
|
||||
try:
|
||||
await self.queue.put(data)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"Queue on event handler {self.route_name} is full! Discarding event {data}"
|
||||
)
|
||||
|
||||
async def consumer(self):
|
||||
with logging_context(action='consumer'):
|
||||
while True:
|
||||
try:
|
||||
item = await self.queue.get()
|
||||
self.batch.append(item)
|
||||
if len(self.batch) > self.batch_size:
|
||||
await self.process_batch()
|
||||
except asyncio.CancelledError:
|
||||
# Try and process the last batch
|
||||
logger.info(
|
||||
f"Event handler {self.route_name} received cancellation signal! "
|
||||
"Trying to process last batch."
|
||||
)
|
||||
if self.batch:
|
||||
await self.process_batch()
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Event handler {self.route_name} received unhandled error."
|
||||
" Ignoring and continuing cautiously."
|
||||
)
|
||||
pass
|
||||
|
||||
async def process_batch(self):
|
||||
with logging_context(action='batch'):
|
||||
logger.debug("Processing Batch")
|
||||
# TODO: copy syntax might be more efficient here
|
||||
await self.model.table.insert_many(
|
||||
self.struct._fields,
|
||||
*map(tuple, self.batch)
|
||||
)
|
||||
self.batch.clear()
|
||||
|
||||
def bind(self, client: AppClient):
|
||||
"""
|
||||
Bind our route to the given client.
|
||||
"""
|
||||
if self._client:
|
||||
raise ValueError("This EventHandler is already attached!")
|
||||
|
||||
self._client = client
|
||||
self.route._client = client
|
||||
client.routes[self.route_name] = self.route
|
||||
return self
|
||||
|
||||
def unbind(self):
|
||||
"""
|
||||
Unbind from the client.
|
||||
"""
|
||||
if not self._client:
|
||||
raise ValueError("Not attached, cannot detach!")
|
||||
self._client.routes.pop(self.route_name, None)
|
||||
self._route = None
|
||||
logger.info(
|
||||
f"EventHandler {self.route_name} has attached to the ShardTalk client."
|
||||
)
|
||||
return self
|
||||
|
||||
async def attach(self, client: AppClient):
|
||||
"""
|
||||
Attach to a ShardTalk client and start listening.
|
||||
"""
|
||||
with logging_context(action=self.route_name):
|
||||
self.bind(client)
|
||||
self._consumer_task = asyncio.create_task(self.consumer())
|
||||
logger.info(
|
||||
f"EventHandler {self.route_name} is listening for incoming events."
|
||||
)
|
||||
return self
|
||||
|
||||
async def detach(self):
|
||||
"""
|
||||
Stop listening and detach from client.
|
||||
"""
|
||||
self.unbind()
|
||||
if self._consumer_task and not self._consumer_task.done():
|
||||
self._consumer_task.cancel()
|
||||
self._consumer_task = None
|
||||
logger.info(
|
||||
f"EventHandler {self.route_name} has detached."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class CommandEvent(NamedTuple):
|
||||
appname: str
|
||||
cmdname: str
|
||||
userid: int
|
||||
created_at: datetime.datetime
|
||||
status: CommandStatus
|
||||
execution_time: float
|
||||
error: Optional[str] = None
|
||||
cogname: Optional[str] = None
|
||||
guildid: Optional[int] = None
|
||||
ctxid: Optional[int] = None
|
||||
|
||||
|
||||
command_event_handler: EventHandler[CommandEvent] = EventHandler(
|
||||
'command_event', AnalyticsData.Commands, CommandEvent, batchsize=1
|
||||
)
|
||||
|
||||
|
||||
class GuildEvent(NamedTuple):
|
||||
appname: str
|
||||
guildid: int
|
||||
action: GuildAction
|
||||
created_at: datetime.datetime
|
||||
|
||||
|
||||
guild_event_handler: EventHandler[GuildEvent] = EventHandler(
|
||||
'guild_event', AnalyticsData.Guilds, GuildEvent, batchsize=0
|
||||
)
|
||||
|
||||
|
||||
class VoiceEvent(NamedTuple):
|
||||
appname: str
|
||||
guildid: int
|
||||
userid: int
|
||||
action: VoiceAction
|
||||
created_at: datetime.datetime
|
||||
|
||||
|
||||
voice_event_handler: EventHandler[VoiceEvent] = EventHandler(
|
||||
'voice_event', AnalyticsData.VoiceSession, VoiceEvent, batchsize=5
|
||||
)
|
||||
128
src/analytics/server.py
Normal file
128
src/analytics/server.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from meta import conf, appname
|
||||
from meta.logger import log_context, log_action_stack, logging_context, log_app, log_wrap
|
||||
from meta.ipc import AppClient
|
||||
from meta.app import appname_from_shard
|
||||
from meta.sharding import shard_count
|
||||
|
||||
from data import Database
|
||||
|
||||
from .events import command_event_handler, guild_event_handler, voice_event_handler
|
||||
from .snapshot import shard_snapshot, ShardSnapshot
|
||||
from .data import AnalyticsData
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
|
||||
logging.getLogger(name).setLevel(conf.logging_levels[name])
|
||||
|
||||
|
||||
class AnalyticsServer:
|
||||
# TODO: Move these to the config
|
||||
# How often to request snapshots
|
||||
snap_period = 120
|
||||
# How soon after a snapshot failure (e.g. not all shards online) to retry
|
||||
snap_retry_period = 10
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.db = Database(conf.data['args'])
|
||||
self.data = self.db.load_registry(AnalyticsData())
|
||||
|
||||
self.event_handlers = [
|
||||
command_event_handler,
|
||||
guild_event_handler,
|
||||
voice_event_handler
|
||||
]
|
||||
|
||||
self.talk = AppClient(
|
||||
conf.analytics['appname'],
|
||||
appname,
|
||||
{'host': conf.analytics['server_host'], 'port': int(conf.analytics['server_port'])},
|
||||
{'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
|
||||
)
|
||||
self.talk_shard_snapshot = self.talk.register_route()(shard_snapshot)
|
||||
|
||||
self._snap_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def attach_event_handlers(self):
|
||||
for handler in self.event_handlers:
|
||||
await handler.attach(self.talk)
|
||||
|
||||
@log_wrap(action='Snap')
|
||||
async def take_snapshot(self):
|
||||
# Check if all the shards are registered on shard_talk
|
||||
expected_peers = [appname_from_shard(i) for i in range(0, shard_count)]
|
||||
if missing := [peer for peer in expected_peers if peer not in self.talk.peers]:
|
||||
# We are missing peer(s)!
|
||||
logger.warning(
|
||||
f"Analytics could not take snapshot because peers are missing: {', '.join(missing)}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Everyone is here, ask for shard snapshots
|
||||
results = await self.talk_shard_snapshot().broadcast()
|
||||
|
||||
# Make sure everyone sent results and there were no exceptions (e.g. concurrency)
|
||||
if not all(result is not None and not isinstance(result, Exception) for result in results.values()):
|
||||
# This should essentially never happen
|
||||
# Either some of the shards could not make a snapshot (e.g. Discord client issues)
|
||||
# or they disconnected in the process.
|
||||
logger.warning(
|
||||
f"Analytics could not take snapshot because some peers failed! Partial snapshot: {results}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Now we have a dictionary of shard snapshots, aggregate, pull in remaining data, and store.
|
||||
# TODO Possibly move this out into snapshots.py?
|
||||
aggregate = {field: 0 for field in ShardSnapshot._fields}
|
||||
for result in results.values():
|
||||
for field, num in result._asdict().items():
|
||||
aggregate[field] += num
|
||||
|
||||
row = await self.data.Snapshots.create(
|
||||
appname=appname,
|
||||
guild_count=aggregate['guild_count'],
|
||||
member_count=aggregate['member_count'],
|
||||
user_count=aggregate['user_count'],
|
||||
in_voice=aggregate['voice_count'],
|
||||
)
|
||||
logger.info(f"Created snapshot: {row.data!r}")
|
||||
return True
|
||||
|
||||
@log_wrap(action='SnapLoop')
|
||||
async def snapshot_loop(self):
|
||||
while True:
|
||||
try:
|
||||
result = await self.take_snapshot()
|
||||
if result:
|
||||
await asyncio.sleep(self.snap_period)
|
||||
else:
|
||||
logger.info("Snapshot failed, retrying after %d seconds", self.snap_retry_period)
|
||||
await asyncio.sleep(self.snap_retry_period)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Snapshot loop cancelled, closing.")
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Unhandled exception during snapshot loop. Ignoring and continuing cautiously."
|
||||
)
|
||||
await asyncio.sleep(self.snap_retry_period)
|
||||
|
||||
async def run(self):
|
||||
log_action_stack.set(['Analytics'])
|
||||
log_app.set(conf.analytics['appname'])
|
||||
|
||||
async with await self.db.connect():
|
||||
await self.talk.connect()
|
||||
await self.attach_event_handlers()
|
||||
self._snap_task = asyncio.create_task(self.snapshot_loop())
|
||||
await asyncio.gather(*(handler._consumer_task for handler in self.event_handlers))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
server = AnalyticsServer()
|
||||
asyncio.run(server.run())
|
||||
28
src/analytics/snapshot.py
Normal file
28
src/analytics/snapshot.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
from meta.context import ctx_bot
|
||||
|
||||
|
||||
class ShardSnapshot(NamedTuple):
|
||||
guild_count: int
|
||||
voice_count: int
|
||||
member_count: int
|
||||
user_count: int
|
||||
|
||||
|
||||
async def shard_snapshot():
|
||||
"""
|
||||
Take a snapshot of the current shard.
|
||||
"""
|
||||
bot = ctx_bot.get()
|
||||
if bot is None or not bot.is_ready():
|
||||
# We cannot take a snapshot without Bot
|
||||
# Just quietly fail
|
||||
return None
|
||||
snap = ShardSnapshot(
|
||||
guild_count=len(bot.guilds),
|
||||
voice_count=sum(len(channel.members) for guild in bot.guilds for channel in guild.voice_channels),
|
||||
member_count=sum(len(guild.members) for guild in bot.guilds),
|
||||
user_count=len(set(m.id for guild in bot.guilds for m in guild.members))
|
||||
)
|
||||
return snap
|
||||
6
src/babel/__init__.py
Normal file
6
src/babel/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .translator import SOURCE_LOCALE, LeoBabel, LocalBabel, LazyStr, ctx_locale, ctx_translator
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
from .cog import BabelCog
|
||||
await bot.add_cog(BabelCog(bot))
|
||||
300
src/babel/cog.py
Normal file
300
src/babel/cog.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Babel Cog.
|
||||
|
||||
Calculates and sets current locale before command runs (via check_once).
|
||||
Also defines the relevant guild and user settings for localisation.
|
||||
"""
|
||||
from typing import Optional
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
from meta import LionBot, LionCog, LionContext
|
||||
from meta.errors import UserInputError
|
||||
|
||||
from settings import ModelData
|
||||
from settings.setting_types import StringSetting, BoolSetting
|
||||
from settings.groups import SettingGroup
|
||||
|
||||
from core.data import CoreData
|
||||
|
||||
from .translator import ctx_locale, ctx_translator, LocalBabel, SOURCE_LOCALE
|
||||
|
||||
babel = LocalBabel('babel')
|
||||
_ = babel._
|
||||
_p = babel._p
|
||||
|
||||
|
||||
class LocaleSettings(SettingGroup):
|
||||
class UserLocale(ModelData, StringSetting):
|
||||
"""
|
||||
User-configured locale.
|
||||
|
||||
Exposed via dedicated setting command.
|
||||
"""
|
||||
setting_id = 'user_locale'
|
||||
|
||||
display_name = _p('userset:locale', 'language')
|
||||
desc = _p('userset:locale|desc', "Your preferred language for interacting with me.")
|
||||
|
||||
_model = CoreData.User
|
||||
_column = CoreData.User.locale.name
|
||||
|
||||
@property
|
||||
def update_message(self):
|
||||
t = ctx_translator.get().t
|
||||
if self.data is None:
|
||||
return t(_p('userset:locale|response', "You have unset your language."))
|
||||
else:
|
||||
return t(_p('userset:locale|response', "You have set your language to `{lang}`.")).format(
|
||||
lang=self.data
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _parse_string(cls, parent_id, string, **kwargs):
|
||||
translator = ctx_translator.get()
|
||||
if string not in translator.supported_locales:
|
||||
lang = string[:20]
|
||||
raise UserInputError(
|
||||
translator.t(
|
||||
_p('userset:locale|error', "Sorry, we do not support the `{lang}` language at this time!")
|
||||
).format(lang=lang)
|
||||
)
|
||||
return string
|
||||
|
||||
class ForceLocale(ModelData, BoolSetting):
|
||||
"""
|
||||
Guild configuration for whether to force usage of the guild locale.
|
||||
|
||||
Exposed via `/configure language` command and standard configuration interface.
|
||||
"""
|
||||
setting_id = 'force_locale'
|
||||
|
||||
display_name = _p('guildset:force_locale', 'force_language')
|
||||
desc = _p('guildset:force_locale|desc',
|
||||
"Whether to force all members to use the configured guild language when interacting with me.")
|
||||
long_desc = _p(
|
||||
'guildset:force_locale|long_desc',
|
||||
"When enabled, commands in this guild will always use the configured guild language, "
|
||||
"regardless of the member's personally configured language."
|
||||
)
|
||||
_outputs = {
|
||||
True: _p('guildset:force_locale|output', 'Enabled (members will be forced to use the server language)'),
|
||||
False: _p('guildset:force_locale|output', 'Disabled (members may set their own language)'),
|
||||
None: 'Not Set' # This should be impossible, since we have a default
|
||||
}
|
||||
_default = False
|
||||
|
||||
_model = CoreData.Guild
|
||||
_column = CoreData.Guild.force_locale.name
|
||||
|
||||
@property
|
||||
def update_message(self):
|
||||
t = ctx_translator.get().t
|
||||
if self.data:
|
||||
return t(_p(
|
||||
'guildset:force_locale|response',
|
||||
"I will always use the set language in this server."
|
||||
))
|
||||
else:
|
||||
return t(_p(
|
||||
'guildset:force_locale|response',
|
||||
"I will now allow the members to set their own language here."
|
||||
))
|
||||
|
||||
class GuildLocale(ModelData, StringSetting):
|
||||
"""
|
||||
Guild-configured locale.
|
||||
|
||||
Exposed via `/configure language` command, and standard configuration interface.
|
||||
"""
|
||||
setting_id = 'guild_locale'
|
||||
|
||||
display_name = _p('guildset:locale', 'language')
|
||||
desc = _p('guildset:locale|desc', "Your preferred language for interacting with me.")
|
||||
|
||||
_model = CoreData.Guild
|
||||
_column = CoreData.Guild.locale.name
|
||||
|
||||
@property
|
||||
def update_message(self):
|
||||
t = ctx_translator.get().t
|
||||
if self.data is None:
|
||||
return t(_p('guildset:locale|response', "You have reset the guild language."))
|
||||
else:
|
||||
return t(_p('guildset:locale|response', "You have set the guild language to `{lang}`.")).format(
|
||||
lang=self.data
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _parse_string(cls, parent_id, string, **kwargs):
|
||||
translator = ctx_translator.get()
|
||||
if string not in translator.supported_locales:
|
||||
lang = string[:20]
|
||||
raise UserInputError(
|
||||
translator.t(
|
||||
_p('guildset:locale|error', "Sorry, we do not support the `{lang}` language at this time!")
|
||||
).format(lang=lang)
|
||||
)
|
||||
return string
|
||||
|
||||
|
||||
class BabelCog(LionCog):
|
||||
depends = {'CoreCog'}
|
||||
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.settings = LocaleSettings()
|
||||
self.t = self.bot.translator.t
|
||||
|
||||
async def cog_load(self):
|
||||
if not self.bot.core:
|
||||
raise ValueError("CoreCog must be loaded first!")
|
||||
self.bot.core.guild_settings.attach(LocaleSettings.ForceLocale)
|
||||
self.bot.core.guild_settings.attach(LocaleSettings.GuildLocale)
|
||||
self.bot.core.user_settings.attach(LocaleSettings.UserLocale)
|
||||
|
||||
async def cog_unload(self):
|
||||
pass
|
||||
|
||||
async def get_user_locale(self, userid):
|
||||
"""
|
||||
Fetch the best locale we can guess for this userid.
|
||||
"""
|
||||
data = await self.bot.core.data.User.fetch(userid)
|
||||
if data:
|
||||
return data.locale or data.locale_hint or SOURCE_LOCALE
|
||||
else:
|
||||
return SOURCE_LOCALE
|
||||
|
||||
async def bot_check_once(self, ctx: LionContext): # type: ignore # Type checker doesn't understand coro checks
|
||||
"""
|
||||
Calculate and inject the current locale before the command begins.
|
||||
|
||||
Locale resolution is calculated as follows:
|
||||
If the guild has force_locale enabled, and a locale set,
|
||||
then the guild's locale will be used.
|
||||
|
||||
Otherwise, the priority is
|
||||
user_locale -> command_locale -> user_locale_hint -> guild_locale -> default_locale
|
||||
"""
|
||||
locale = None
|
||||
if ctx.guild:
|
||||
forced = ctx.alion.guild_settings['force_locale'].value
|
||||
guild_locale = ctx.alion.guild_settings['guild_locale'].value
|
||||
if forced:
|
||||
locale = guild_locale
|
||||
|
||||
locale = locale or ctx.alion.user_settings['user_locale'].value
|
||||
if ctx.interaction:
|
||||
locale = locale or ctx.interaction.locale.value
|
||||
if ctx.guild:
|
||||
locale = locale or guild_locale
|
||||
|
||||
locale = locale or SOURCE_LOCALE
|
||||
|
||||
ctx_locale.set(locale)
|
||||
ctx_translator.set(self.bot.translator)
|
||||
return True
|
||||
|
||||
@cmds.hybrid_command(
|
||||
name=LocaleSettings.UserLocale.display_name,
|
||||
description=LocaleSettings.UserLocale.desc
|
||||
)
|
||||
async def cmd_language(self, ctx: LionContext, language: str):
|
||||
"""
|
||||
Dedicated user setting command for the `locale` setting.
|
||||
"""
|
||||
if not ctx.interaction:
|
||||
# This command is not available as a text command
|
||||
return
|
||||
|
||||
setting = await self.settings.UserLocale.get(ctx.author.id)
|
||||
new_data = await setting._parse_string(ctx.author.id, language)
|
||||
await setting.interactive_set(new_data, ctx.interaction)
|
||||
|
||||
@cmds.hybrid_command(
|
||||
name=_p('cmd:configure_language', "configure_language"),
|
||||
description=_p('cmd:configure_language|desc',
|
||||
"Configure the default language I will use in this server.")
|
||||
)
|
||||
@appcmds.choices(
|
||||
force_language=[
|
||||
appcmds.Choice(name=LocaleSettings.ForceLocale._outputs[True], value=1),
|
||||
appcmds.Choice(name=LocaleSettings.ForceLocale._outputs[False], value=0),
|
||||
]
|
||||
)
|
||||
@appcmds.guild_only() # Can be removed when attached as a subcommand
|
||||
async def cmd_configure_language(
|
||||
self, ctx: LionContext, language: Optional[str] = None, force_language: Optional[appcmds.Choice[int]] = None
|
||||
):
|
||||
if not ctx.interaction:
|
||||
# This command is not available as a text command
|
||||
return
|
||||
if not ctx.guild:
|
||||
# This is impossible by decorators, but adding this guard for the type checker
|
||||
return
|
||||
t = self.t
|
||||
# TODO: Setting group, and group setting widget
|
||||
# We can attach the command to the setting group as an application command
|
||||
# Then load it into the configure command group dynamically
|
||||
|
||||
lang_setting = await self.settings.GuildLocale.get(ctx.guild.id)
|
||||
force_setting = await self.settings.ForceLocale.get(ctx.guild.id)
|
||||
|
||||
if language:
|
||||
lang_data = await lang_setting._parse_string(ctx.guild.id, language)
|
||||
if force_language is not None:
|
||||
force_data = bool(force_language)
|
||||
|
||||
if force_language is not None and not (lang_data if language is not None else lang_setting.value):
|
||||
# Setting force without having a language!
|
||||
raise UserInputError(
|
||||
t(_p(
|
||||
'cmd:configure_language|error',
|
||||
"You cannot enable `{force_setting}` without having a configured language!"
|
||||
)).format(force_setting=t(LocaleSettings.ForceLocale.display_name))
|
||||
)
|
||||
# TODO: Really need simultaneous model writes, or batched writes
|
||||
lines = []
|
||||
if language:
|
||||
lang_setting.data = lang_data
|
||||
await lang_setting.write()
|
||||
lines.append(lang_setting.update_message)
|
||||
if force_language is not None:
|
||||
force_setting.data = force_data
|
||||
await force_setting.write()
|
||||
lines.append(force_setting.update_message)
|
||||
result = '\n'.join(
|
||||
f"{self.bot.config.emojis.tick} {line}" for line in lines
|
||||
)
|
||||
# TODO: Setting group widget
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.green(),
|
||||
title=t(_p('cmd:configure_language|success', "Language settings updated!")),
|
||||
description=result
|
||||
)
|
||||
)
|
||||
|
||||
@cmd_configure_language.autocomplete('language')
|
||||
async def cmd_configure_language_acmpl_language(self, interaction: discord.Interaction, partial: str):
|
||||
# TODO: More friendly language names
|
||||
supported = self.bot.translator.supported_locales
|
||||
matching = [lang for lang in supported if partial.lower() in lang]
|
||||
t = self.t
|
||||
if not matching:
|
||||
return [
|
||||
appcmds.Choice(
|
||||
name=t(_p(
|
||||
'cmd:configure_language|acmpl:language',
|
||||
"No supported languages matching {partial}"
|
||||
)).format(partial=partial),
|
||||
value='None'
|
||||
)
|
||||
]
|
||||
else:
|
||||
return [
|
||||
appcmds.Choice(name=lang, value=lang)
|
||||
for lang in matching
|
||||
]
|
||||
34
src/babel/enums.py
Normal file
34
src/babel/enums.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LocaleMap(Enum):
|
||||
american_english = 'en-US'
|
||||
british_english = 'en-GB'
|
||||
bulgarian = 'bg'
|
||||
chinese = 'zh-CN'
|
||||
taiwan_chinese = 'zh-TW'
|
||||
croatian = 'hr'
|
||||
czech = 'cs'
|
||||
danish = 'da'
|
||||
dutch = 'nl'
|
||||
finnish = 'fi'
|
||||
french = 'fr'
|
||||
german = 'de'
|
||||
greek = 'el'
|
||||
hindi = 'hi'
|
||||
hungarian = 'hu'
|
||||
italian = 'it'
|
||||
japanese = 'ja'
|
||||
korean = 'ko'
|
||||
lithuanian = 'lt'
|
||||
norwegian = 'no'
|
||||
polish = 'pl'
|
||||
brazil_portuguese = 'pt-BR'
|
||||
romanian = 'ro'
|
||||
russian = 'ru'
|
||||
spain_spanish = 'es-ES'
|
||||
swedish = 'sv-SE'
|
||||
thai = 'th'
|
||||
turkish = 'tr'
|
||||
ukrainian = 'uk'
|
||||
vietnamese = 'vi'
|
||||
157
src/babel/translator.py
Normal file
157
src/babel/translator.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import gettext
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
|
||||
from discord.app_commands import Translator, locale_str
|
||||
from discord.enums import Locale
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SOURCE_LOCALE = 'en_uk'
|
||||
ctx_locale: ContextVar[str] = ContextVar('locale', default=SOURCE_LOCALE)
|
||||
ctx_translator: ContextVar['LeoBabel'] = ContextVar('translator', default=None) # type: ignore
|
||||
|
||||
null = gettext.NullTranslations()
|
||||
|
||||
|
||||
class LeoBabel(Translator):
|
||||
def __init__(self):
|
||||
self.supported_locales = {loc.name for loc in Locale}
|
||||
self.supported_domains = {}
|
||||
self.translators = defaultdict(dict) # locale -> domain -> GNUTranslator
|
||||
|
||||
def read_supported(self):
|
||||
"""
|
||||
Load supported localisations and domains from the config.
|
||||
"""
|
||||
from meta import conf
|
||||
|
||||
locales = conf.babel.get('locales', '')
|
||||
stripped = (loc.strip(', ') for loc in locales.split(','))
|
||||
self.supported_locales = {loc for loc in stripped if loc}
|
||||
|
||||
domains = conf.babel.get('domains', '')
|
||||
stripped = (dom.strip(', ') for dom in domains.split(','))
|
||||
self.supported_domains = {dom for dom in stripped if dom}
|
||||
|
||||
async def load(self):
|
||||
"""
|
||||
Initialise the gettext translators for the supported_locales.
|
||||
"""
|
||||
self.read_supported()
|
||||
for locale in self.supported_locales:
|
||||
for domain in self.supported_domains:
|
||||
if locale == SOURCE_LOCALE:
|
||||
continue
|
||||
try:
|
||||
translator = gettext.translation(domain, "locales/", languages=[locale])
|
||||
except OSError:
|
||||
# Presume translation does not exist
|
||||
logger.warning(f"Could not load translator for supported <locale: {locale}> <domain: {domain}>")
|
||||
pass
|
||||
else:
|
||||
logger.debug(f"Loaded translator for <locale: {locale}> <domain: {domain}>")
|
||||
self.translators[locale][domain] = translator
|
||||
|
||||
async def unload(self):
|
||||
self.translators.clear()
|
||||
|
||||
def get_translator(self, locale, domain):
|
||||
if locale == SOURCE_LOCALE:
|
||||
return null
|
||||
|
||||
translator = self.translators[locale].get(domain, None)
|
||||
if translator is None:
|
||||
logger.warning(
|
||||
f"Translator missing for requested <locale: {locale}> and <domain: {domain}>. Setting NullTranslator."
|
||||
)
|
||||
self.translators[locale][domain] = null
|
||||
translator = null
|
||||
return translator
|
||||
|
||||
def t(self, lazystr, locale=None):
|
||||
domain = lazystr.domain
|
||||
translator = self.get_translator(locale or lazystr.locale, domain)
|
||||
return lazystr._translate_with(translator)
|
||||
|
||||
async def translate(self, string: locale_str, locale: Locale, context):
|
||||
if locale.value in self.supported_locales:
|
||||
domain = string.extras.get('domain', None)
|
||||
if domain is None and isinstance(string, LazyStr):
|
||||
logger.debug(
|
||||
f"LeoBabel cannot translate a locale_str with no domain set. Context: {context}, String: {string}"
|
||||
)
|
||||
return None
|
||||
|
||||
translator = self.get_translator(locale.value, domain)
|
||||
if not isinstance(string, LazyStr):
|
||||
lazy = LazyStr(Method.GETTEXT, string.message)
|
||||
else:
|
||||
lazy = string
|
||||
return lazy._translate_with(translator)
|
||||
|
||||
|
||||
class Method(Enum):
|
||||
GETTEXT = 'gettext'
|
||||
NGETTEXT = 'ngettext'
|
||||
PGETTEXT = 'pgettext'
|
||||
NPGETTEXT = 'npgettext'
|
||||
|
||||
|
||||
class LocalBabel:
|
||||
def __init__(self, domain):
|
||||
self.domain = domain
|
||||
|
||||
@property
|
||||
def methods(self):
|
||||
return (self._, self._n, self._p, self._np)
|
||||
|
||||
def _(self, message):
|
||||
return LazyStr(Method.GETTEXT, message, domain=self.domain)
|
||||
|
||||
def _n(self, singular, plural, n):
|
||||
return LazyStr(Method.NGETTEXT, singular, plural, n, domain=self.domain)
|
||||
|
||||
def _p(self, context, message):
|
||||
return LazyStr(Method.PGETTEXT, context, message, domain=self.domain)
|
||||
|
||||
def _np(self, context, singular, plural, n):
|
||||
return LazyStr(Method.NPGETTEXT, context, singular, plural, n, domain=self.domain)
|
||||
|
||||
|
||||
class LazyStr(locale_str):
|
||||
__slots__ = ('method', 'args', 'domain', 'locale')
|
||||
|
||||
def __init__(self, method, *args, locale=None, domain=None):
|
||||
self.method = method
|
||||
self.args = args
|
||||
self.domain = domain
|
||||
self.locale = locale or ctx_locale.get()
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self._translate_with(null)
|
||||
|
||||
@property
|
||||
def extras(self):
|
||||
return {'locale': self.locale, 'domain': self.domain}
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
def _translate_with(self, translator: gettext.GNUTranslations):
|
||||
method = getattr(translator, self.method.value)
|
||||
return method(*self.args)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}({self.method}, {self.args!r}, locale={self.locale}, domain={self.domain})'
|
||||
|
||||
def __eq__(self, obj: object) -> bool:
|
||||
return isinstance(obj, locale_str) and self.message == obj.message
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.args)
|
||||
87
src/bot.py
Normal file
87
src/bot.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from meta import LionBot, conf, sharding, appname, shard_talk
|
||||
from meta.app import shardname
|
||||
from meta.logger import log_context, log_action_stack, logging_context
|
||||
from meta.context import ctx_bot
|
||||
|
||||
from data import Database
|
||||
|
||||
from babel.translator import LeoBabel, ctx_translator
|
||||
|
||||
from constants import DATA_VERSION
|
||||
|
||||
|
||||
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
|
||||
logging.getLogger(name).setLevel(conf.logging_levels[name])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
db = Database(conf.data['args'])
|
||||
|
||||
|
||||
async def main():
|
||||
log_action_stack.set(["Initialising"])
|
||||
logger.info("Initialising StudyLion")
|
||||
|
||||
intents = discord.Intents.all()
|
||||
intents.members = True
|
||||
intents.message_content = True
|
||||
|
||||
async with await db.connect():
|
||||
version = await db.version()
|
||||
if version.version != DATA_VERSION:
|
||||
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
|
||||
logger.critical(error)
|
||||
raise RuntimeError(error)
|
||||
|
||||
translator = LeoBabel()
|
||||
ctx_translator.set(translator)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with LionBot(
|
||||
command_prefix=commands.when_mentioned,
|
||||
intents=intents,
|
||||
appname=appname,
|
||||
shardname=shardname,
|
||||
db=db,
|
||||
config=conf,
|
||||
initial_extensions=['utils', 'core', 'analytics', 'babel', 'modules'],
|
||||
web_client=session,
|
||||
app_ipc=shard_talk,
|
||||
testing_guilds=conf.bot.getintlist('admin_guilds'),
|
||||
shard_id=sharding.shard_number,
|
||||
shard_count=sharding.shard_count,
|
||||
translator=translator
|
||||
) as lionbot:
|
||||
ctx_bot.set(lionbot)
|
||||
try:
|
||||
with logging_context(context=f"APP: {appname}"):
|
||||
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
|
||||
await lionbot.start(conf.bot['TOKEN'])
|
||||
except asyncio.CancelledError:
|
||||
with logging_context(context=f"APP: {appname}", action="Shutting Down"):
|
||||
logger.info("StudyLion closed, shutting down.", exc_info=True)
|
||||
|
||||
|
||||
def _main():
|
||||
from signal import SIGINT, SIGTERM
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
main_task = asyncio.ensure_future(main())
|
||||
for signal in [SIGINT, SIGTERM]:
|
||||
loop.add_signal_handler(signal, main_task.cancel)
|
||||
try:
|
||||
loop.run_until_complete(main_task)
|
||||
finally:
|
||||
loop.close()
|
||||
logging.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_main()
|
||||
4
src/constants.py
Normal file
4
src/constants.py
Normal file
@@ -0,0 +1,4 @@
|
||||
CONFIG_FILE = "config/bot.conf"
|
||||
DATA_VERSION = 13
|
||||
|
||||
MAX_COINS = 2147483647 - 1
|
||||
5
src/core/__init__.py
Normal file
5
src/core/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .cog import CoreCog
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(CoreCog(bot))
|
||||
89
src/core/cog.py
Normal file
89
src/core/cog.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import Optional
|
||||
|
||||
import discord
|
||||
|
||||
from meta import LionBot, LionCog, LionContext
|
||||
from meta.app import shardname, appname
|
||||
from meta.logger import log_wrap
|
||||
from utils.lib import utc_now
|
||||
|
||||
from settings.groups import SettingGroup
|
||||
|
||||
from .data import CoreData
|
||||
from .lion import Lions
|
||||
from .guild_settings import GuildSettings
|
||||
from .user_settings import UserSettings
|
||||
|
||||
|
||||
class CoreCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = CoreData()
|
||||
bot.db.load_registry(self.data)
|
||||
self.lions = Lions(bot)
|
||||
|
||||
self.app_config: Optional[CoreData.AppConfig] = None
|
||||
self.bot_config: Optional[CoreData.BotConfig] = None
|
||||
self.shard_data: Optional[CoreData.Shard] = None
|
||||
|
||||
# Some global setting registries
|
||||
# Do not use these for direct setting access
|
||||
# Instead, import the setting directly or use the cog API
|
||||
self.bot_setting_groups: list[SettingGroup] = []
|
||||
self.guild_setting_groups: list[SettingGroup] = []
|
||||
self.user_setting_groups: list[SettingGroup] = []
|
||||
|
||||
# Some ModelSetting registries
|
||||
# These are for more convenient direct access
|
||||
self.guild_settings = GuildSettings
|
||||
self.user_settings = UserSettings
|
||||
|
||||
self.app_cmd_cache: list[discord.app_commands.AppCommand] = []
|
||||
self.cmd_name_cache: dict[str, discord.app_commands.AppCommand] = {}
|
||||
|
||||
async def bot_check_once(self, ctx: LionContext): # type: ignore
|
||||
lion = await self.lions.fetch(ctx.guild.id if ctx.guild else 0, ctx.author.id)
|
||||
if ctx.guild:
|
||||
await lion.touch_discord_models(ctx.author) # type: ignore # Type checker doesn't recognise guard
|
||||
ctx.alion = lion
|
||||
return True
|
||||
|
||||
async def cog_load(self):
|
||||
# Fetch (and possibly create) core data rows.
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
|
||||
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
|
||||
self.shard_data = await self.data.Shard.fetch_or_create(
|
||||
shardname,
|
||||
appname=appname,
|
||||
shard_id=self.bot.shard_id,
|
||||
shard_count=self.bot.shard_count
|
||||
)
|
||||
self.bot.add_listener(self.shard_update_guilds, name='on_guild_join')
|
||||
self.bot.add_listener(self.shard_update_guilds, name='on_guild_remove')
|
||||
|
||||
self.bot.core = self
|
||||
await self.bot.add_cog(self.lions)
|
||||
|
||||
# Load the app command cache
|
||||
for guildid in self.bot.testing_guilds:
|
||||
self.app_cmd_cache += await self.bot.tree.fetch_commands(guild=discord.Object(guildid))
|
||||
self.app_cmd_cache += await self.bot.tree.fetch_commands()
|
||||
self.cmd_name_cache = {cmd.name: cmd for cmd in self.app_cmd_cache}
|
||||
|
||||
async def cog_unload(self):
|
||||
await self.bot.remove_cog(self.lions.qualified_name)
|
||||
self.bot.remove_listener(self.shard_update_guilds, name='on_guild_join')
|
||||
self.bot.remove_listener(self.shard_update_guilds, name='on_guild_leave')
|
||||
self.bot.core = None
|
||||
|
||||
@LionCog.listener('on_ready')
|
||||
@log_wrap(action='Touch shard data')
|
||||
async def touch_shard_data(self):
|
||||
# Update the last login and guild count for this shard
|
||||
await self.shard_data.update(last_login=utc_now(), guild_count=len(self.bot.guilds))
|
||||
|
||||
@log_wrap(action='Update shard guilds')
|
||||
async def shard_update_guilds(self, guild):
|
||||
await self.shard_data.update(guild_count=len(self.bot.guilds))
|
||||
315
src/core/data.py
Normal file
315
src/core/data.py
Normal file
@@ -0,0 +1,315 @@
|
||||
from itertools import chain
|
||||
from psycopg import sql
|
||||
from cachetools import TTLCache
|
||||
|
||||
from data import Table, Registry, Column, RowModel
|
||||
from data.models import WeakCache
|
||||
from data.columns import Integer, String, Bool, Timestamp
|
||||
|
||||
|
||||
class CoreData(Registry, name="core"):
|
||||
class AppConfig(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE app_config(
|
||||
appname TEXT PRIMARY KEY,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'app_config'
|
||||
|
||||
appname = String(primary=True)
|
||||
created_at = Timestamp()
|
||||
|
||||
class BotConfig(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE bot_config(
|
||||
appname TEXT PRIMARY KEY REFERENCES app_config(appname) ON DELETE CASCADE,
|
||||
default_skin TEXT
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'bot_config'
|
||||
|
||||
appname = String(primary=True)
|
||||
default_skin = String()
|
||||
|
||||
class Shard(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE shard_data(
|
||||
shardname TEXT PRIMARY KEY,
|
||||
appname TEXT REFERENCES bot_config(appname) ON DELETE CASCADE,
|
||||
shard_id INTEGER NOT NULL,
|
||||
shard_count INTEGER NOT NULL,
|
||||
last_login TIMESTAMPTZ,
|
||||
guild_count INTEGER
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'shard_data'
|
||||
|
||||
shardname = String(primary=True)
|
||||
appname = String()
|
||||
shard_id = Integer()
|
||||
shard_count = Integer()
|
||||
last_login = Timestamp()
|
||||
guild_count = Integer()
|
||||
|
||||
class User(RowModel):
|
||||
"""
|
||||
User model, representing configuration data for a single user.
|
||||
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE user_config(
|
||||
userid BIGINT PRIMARY KEY,
|
||||
timezone TEXT,
|
||||
topgg_vote_reminder BOOLEAN,
|
||||
avatar_hash TEXT,
|
||||
name TEXT,
|
||||
API_timestamp BIGINT,
|
||||
gems INTEGER DEFAULT 0,
|
||||
first_seen TIMESTAMPTZ DEFAULT now(),
|
||||
last_seen TIMESTAMPTZ,
|
||||
locale TEXT,
|
||||
locale_hint TEXT
|
||||
);
|
||||
"""
|
||||
|
||||
_tablename_ = "user_config"
|
||||
_cache_: WeakCache[tuple[int], 'CoreData.User'] = WeakCache(TTLCache(1000, ttl=60*5))
|
||||
|
||||
userid = Integer(primary=True)
|
||||
timezone = String()
|
||||
topgg_vote_reminder = Bool()
|
||||
avatar_hash = String()
|
||||
name = String()
|
||||
API_timestamp = Integer()
|
||||
gems = Integer()
|
||||
first_seen = Timestamp()
|
||||
last_seen = Timestamp()
|
||||
locale = String()
|
||||
locale_hint = String()
|
||||
|
||||
class Guild(RowModel):
|
||||
"""
|
||||
Guild model, representing configuration data for a single guild.
|
||||
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE guild_config(
|
||||
guildid BIGINT PRIMARY KEY,
|
||||
admin_role BIGINT,
|
||||
mod_role BIGINT,
|
||||
event_log_channel BIGINT,
|
||||
mod_log_channel BIGINT,
|
||||
alert_channel BIGINT,
|
||||
studyban_role BIGINT,
|
||||
min_workout_length INTEGER,
|
||||
workout_reward INTEGER,
|
||||
max_tasks INTEGER,
|
||||
task_reward INTEGER,
|
||||
task_reward_limit INTEGER,
|
||||
study_hourly_reward INTEGER,
|
||||
study_hourly_live_bonus INTEGER,
|
||||
renting_price INTEGER,
|
||||
renting_category BIGINT,
|
||||
renting_cap INTEGER,
|
||||
renting_role BIGINT,
|
||||
renting_sync_perms BOOLEAN,
|
||||
accountability_category BIGINT,
|
||||
accountability_lobby BIGINT,
|
||||
accountability_bonus INTEGER,
|
||||
accountability_reward INTEGER,
|
||||
accountability_price INTEGER,
|
||||
video_studyban BOOLEAN,
|
||||
video_grace_period INTEGER,
|
||||
greeting_channel BIGINT,
|
||||
greeting_message TEXT,
|
||||
returning_message TEXT,
|
||||
starting_funds INTEGER,
|
||||
persist_roles BOOLEAN,
|
||||
daily_study_cap INTEGER,
|
||||
pomodoro_channel BIGINT,
|
||||
name TEXT,
|
||||
first_joined_at TIMESTAMPTZ DEFAULT now(),
|
||||
left_at TIMESTAMPTZ,
|
||||
locale TEXT,
|
||||
force_locale BOOLEAN
|
||||
);
|
||||
|
||||
"""
|
||||
|
||||
_tablename_ = "guild_config"
|
||||
_cache_: WeakCache[tuple[int], 'CoreData.Guild'] = WeakCache(TTLCache(1000, ttl=60*5))
|
||||
|
||||
guildid = Integer(primary=True)
|
||||
|
||||
admin_role = Integer()
|
||||
mod_role = Integer()
|
||||
event_log_channel = Integer()
|
||||
mod_log_channel = Integer()
|
||||
alert_channel = Integer()
|
||||
|
||||
studyban_role = Integer()
|
||||
max_study_bans = Integer()
|
||||
|
||||
min_workout_length = Integer()
|
||||
workout_reward = Integer()
|
||||
|
||||
max_tasks = Integer()
|
||||
task_reward = Integer()
|
||||
task_reward_limit = Integer()
|
||||
|
||||
study_hourly_reward = Integer()
|
||||
study_hourly_live_bonus = Integer()
|
||||
daily_study_cap = Integer()
|
||||
|
||||
renting_price = Integer()
|
||||
renting_category = Integer()
|
||||
renting_cap = Integer()
|
||||
renting_role = Integer()
|
||||
renting_sync_perms = Bool()
|
||||
|
||||
accountability_category = Integer()
|
||||
accountability_lobby = Integer()
|
||||
accountability_bonus = Integer()
|
||||
accountability_reward = Integer()
|
||||
accountability_price = Integer()
|
||||
|
||||
video_studyban = Bool()
|
||||
video_grace_period = Integer()
|
||||
|
||||
greeting_channel = Integer()
|
||||
greeting_message = String()
|
||||
returning_message = String()
|
||||
|
||||
starting_funds = Integer()
|
||||
persist_roles = Bool()
|
||||
|
||||
pomodoro_channel = Integer()
|
||||
|
||||
name = String()
|
||||
|
||||
first_joined_at = Timestamp()
|
||||
left_at = Timestamp()
|
||||
|
||||
locale = String()
|
||||
force_locale = Bool()
|
||||
|
||||
unranked_rows = Table('unranked_rows')
|
||||
|
||||
donator_roles = Table('donator_roles')
|
||||
|
||||
member_ranks = Table('member_ranks')
|
||||
|
||||
class Member(RowModel):
|
||||
"""
|
||||
Member model, representing configuration data for a single member.
|
||||
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE members(
|
||||
guildid BIGINT,
|
||||
userid BIGINT,
|
||||
tracked_time INTEGER DEFAULT 0,
|
||||
coins INTEGER DEFAULT 0,
|
||||
workout_count INTEGER DEFAULT 0,
|
||||
revision_mute_count INTEGER DEFAULT 0,
|
||||
last_workout_start TIMESTAMP,
|
||||
last_study_badgeid INTEGER REFERENCES study_badges ON DELETE SET NULL,
|
||||
video_warned BOOLEAN DEFAULT FALSE,
|
||||
display_name TEXT,
|
||||
first_joined TIMESTAMPTZ DEFAULT now(),
|
||||
last_left TIMESTAMPTZ,
|
||||
_timestamp TIMESTAMP DEFAULT (now() at time zone 'utc'),
|
||||
PRIMARY KEY(guildid, userid)
|
||||
);
|
||||
CREATE INDEX member_timestamps ON members (_timestamp);
|
||||
"""
|
||||
_tablename_ = 'members'
|
||||
_cache_: WeakCache[tuple[int, int], 'CoreData.Member'] = WeakCache(TTLCache(5000, ttl=60*5))
|
||||
|
||||
guildid = Integer(primary=True)
|
||||
userid = Integer(primary=True)
|
||||
|
||||
tracked_time = Integer()
|
||||
coins = Integer()
|
||||
|
||||
workout_count = Integer()
|
||||
revision_mute_count = Integer()
|
||||
last_workout_start = Timestamp()
|
||||
last_study_badgeid = Integer()
|
||||
video_warned = Bool()
|
||||
display_name = String()
|
||||
|
||||
first_joined = Timestamp()
|
||||
last_left = Timestamp()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
@classmethod
|
||||
async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']:
|
||||
"""
|
||||
Safely add pending coins to a list of members.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
pending:
|
||||
List of tuples of the form `(guildid, userid, pending_coins)`.
|
||||
"""
|
||||
query = sql.SQL("""
|
||||
UPDATE members
|
||||
SET
|
||||
coins = LEAST(coins + t.coin_diff, 2147483647)
|
||||
FROM
|
||||
(VALUES {})
|
||||
AS
|
||||
t (guildid, userid, coin_diff)
|
||||
WHERE
|
||||
members.guildid = t.guildid
|
||||
AND
|
||||
members.userid = t.userid
|
||||
RETURNING *
|
||||
""").format(
|
||||
sql.SQL(', ').join(
|
||||
sql.SQL("({}, {}, {})").format(sql.Placeholder(), sql.Placeholder(), sql.Placeholder())
|
||||
for _ in pending
|
||||
)
|
||||
)
|
||||
# TODO: Replace with copy syntax/query?
|
||||
conn = await cls.table.connector.get_connection()
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
tuple(chain(*pending))
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return cls._make_rows(*rows)
|
||||
|
||||
@classmethod
|
||||
async def get_member_rank(cls, guildid, userid, untracked):
|
||||
"""
|
||||
Get the time and coin ranking for the given member, ignoring the provided untracked members.
|
||||
"""
|
||||
conn = await cls.table.connector.get_connection()
|
||||
async with conn.cursor() as curs:
|
||||
await curs.execute(
|
||||
"""
|
||||
SELECT
|
||||
time_rank, coin_rank
|
||||
FROM (
|
||||
SELECT
|
||||
userid,
|
||||
row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank,
|
||||
row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank
|
||||
FROM members_totals
|
||||
WHERE
|
||||
guildid=%s AND userid NOT IN %s
|
||||
) AS guild_ranks WHERE userid=%s
|
||||
""",
|
||||
(guildid, tuple(untracked), userid)
|
||||
)
|
||||
return (await curs.fetchone()) or (None, None)
|
||||
7
src/core/guild_settings.py
Normal file
7
src/core/guild_settings.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from settings.groups import ModelSettings, SettingDotDict
|
||||
from .data import CoreData
|
||||
|
||||
|
||||
class GuildSettings(ModelSettings):
|
||||
_settings = SettingDotDict()
|
||||
model = CoreData.Guild
|
||||
153
src/core/lion.py
Normal file
153
src/core/lion.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from typing import Optional
|
||||
from cachetools import LRUCache
|
||||
import discord
|
||||
|
||||
from meta import LionCog, LionBot, LionContext
|
||||
from settings import InteractiveSetting
|
||||
from utils.lib import utc_now
|
||||
from data import WeakCache
|
||||
|
||||
from .data import CoreData
|
||||
|
||||
from .user_settings import UserSettings
|
||||
from .guild_settings import GuildSettings
|
||||
|
||||
|
||||
class Lion:
|
||||
"""
|
||||
A Lion is a high level representation of a Member in the LionBot paradigm.
|
||||
|
||||
All members interacted with by the application should be available as Lions.
|
||||
It primarily provides an interface to the User and Member data.
|
||||
Lion also provides centralised access to various Member properties and methods,
|
||||
that would normally be served by other cogs.
|
||||
|
||||
Many Lion methods may only be used when the required cogs and extensions are loaded.
|
||||
A Lion may exist without a Bot instance or a Member in cache,
|
||||
although the functionality available will be more limited.
|
||||
|
||||
There is no guarantee that a corresponding discord Member actually exists.
|
||||
"""
|
||||
__slots__ = ('bot', 'data', 'user_data', 'guild_data', '_member', '__weakref__')
|
||||
|
||||
def __init__(self, bot: LionBot, data: CoreData.Member, user_data: CoreData.User, guild_data: CoreData.Guild):
|
||||
self.bot = bot
|
||||
self.data = data
|
||||
self.user_data = user_data
|
||||
self.guild_data = guild_data
|
||||
|
||||
self._member: Optional[discord.Member] = None
|
||||
|
||||
# Data properties
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
return (self.data.guildid, self.data.userid)
|
||||
|
||||
@property
|
||||
def guildid(self):
|
||||
return self.data.guildid
|
||||
|
||||
@property
|
||||
def userid(self):
|
||||
return self.data.userid
|
||||
|
||||
@classmethod
|
||||
def get(cls, guildid, userid):
|
||||
return cls._cache_.get((guildid, userid), None)
|
||||
|
||||
# ModelSettings interfaces
|
||||
@property
|
||||
def guild_settings(self):
|
||||
return GuildSettings(self.guildid, self.guild_data, bot=self.bot)
|
||||
|
||||
@property
|
||||
def user_settings(self):
|
||||
return UserSettings(self.userid, self.user_data, bot=self.bot)
|
||||
|
||||
# Setting interfaces
|
||||
# Each of these return an initialised member setting
|
||||
|
||||
@property
|
||||
def timezone(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def locale(self):
|
||||
pass
|
||||
|
||||
# Time utilities
|
||||
@property
|
||||
def now(self):
|
||||
"""
|
||||
Returns current time-zone aware time for the member.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Discord data cache
|
||||
async def touch_discord_models(self, member: discord.Member):
|
||||
"""
|
||||
Update the stored discord data from the given user or member object.
|
||||
Intended to be used when we get member data from events that may not be available in cache.
|
||||
"""
|
||||
# Can we do these in one query?
|
||||
if member.guild and (self.guild_data.name != member.guild.name):
|
||||
await self.guild_data.update(name=member.guild.name)
|
||||
|
||||
avatar_key = member.avatar.key if member.avatar else None
|
||||
await self.user_data.update(avatar_hash=avatar_key, name=member.name, last_seen=utc_now())
|
||||
|
||||
if member.display_name != self.data.display_name:
|
||||
await self.data.update(display_name=member.display_name)
|
||||
|
||||
async def get_member(self) -> Optional[discord.Member]:
|
||||
"""
|
||||
Retrieve the member object for this Lion, if possible.
|
||||
|
||||
If the guild or member cannot be retrieved, returns None.
|
||||
"""
|
||||
guild = self.bot.get_guild(self.guildid)
|
||||
if guild is not None:
|
||||
member = guild.get_member(self.userid)
|
||||
if member is None:
|
||||
try:
|
||||
member = await guild.fetch_member(self.userid)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
return member
|
||||
|
||||
|
||||
class Lions(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
|
||||
# Full Lions cache
|
||||
# Don't expire Lions with strong references
|
||||
self._cache_: WeakCache[tuple[int, int], 'Lion'] = WeakCache(LRUCache(5000))
|
||||
|
||||
self._settings_: dict[str, InteractiveSetting] = {}
|
||||
|
||||
async def fetch(self, guildid, userid) -> Lion:
|
||||
"""
|
||||
Fetch or create the given Member.
|
||||
If the guild or user row doesn't exist, also creates it.
|
||||
Relies on the core cog existing, to retrieve the core data.
|
||||
"""
|
||||
# TODO: Find a way to reduce this to one query, while preserving cache
|
||||
lion = self._cache_.get((guildid, userid))
|
||||
if lion is None:
|
||||
if self.bot.core:
|
||||
data = self.bot.core.data
|
||||
else:
|
||||
raise ValueError("Cannot fetch Lion before core module is attached.")
|
||||
|
||||
guild = await data.Guild.fetch_or_create(guildid)
|
||||
user = await data.User.fetch_or_create(userid)
|
||||
member = await data.Member.fetch_or_create(guildid, userid)
|
||||
lion = Lion(self.bot, member, user, guild)
|
||||
self._cache_[(guildid, userid)] = lion
|
||||
return lion
|
||||
|
||||
def add_model_setting(self, setting: InteractiveSetting):
|
||||
self._settings_[setting.__class__.__name__] = setting
|
||||
return setting
|
||||
7
src/core/user_settings.py
Normal file
7
src/core/user_settings.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from settings.groups import ModelSettings, SettingDotDict
|
||||
from .data import CoreData
|
||||
|
||||
|
||||
class UserSettings(ModelSettings):
|
||||
_settings = SettingDotDict()
|
||||
model = CoreData.User
|
||||
9
src/data/__init__.py
Normal file
9
src/data/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .conditions import Condition, condition, NULL
|
||||
from .database import Database
|
||||
from .models import RowModel, RowTable, WeakCache
|
||||
from .table import Table
|
||||
from .base import Expression, RawExpr
|
||||
from .columns import ColumnExpr, Column, Integer, String
|
||||
from .registry import Registry, AttachableClass, Attachable
|
||||
from .adapted import RegisterEnum
|
||||
from .queries import ORDER, NULLS, JOINTYPE
|
||||
32
src/data/adapted.py
Normal file
32
src/data/adapted.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# from enum import Enum
|
||||
from typing import Optional
|
||||
from psycopg.types.enum import register_enum, EnumInfo
|
||||
from .registry import Attachable, Registry
|
||||
|
||||
|
||||
class RegisterEnum(Attachable):
|
||||
def __init__(self, enum, name: Optional[str] = None, mapper=None):
|
||||
super().__init__()
|
||||
self.enum = enum
|
||||
self.name = name or enum.__name__
|
||||
self.mapping = mapper(enum) if mapper is not None else self._mapper()
|
||||
|
||||
def _mapper(self):
|
||||
return {m: m.value[0] for m in self.enum}
|
||||
|
||||
def attach_to(self, registry: Registry):
|
||||
self._registry = registry
|
||||
registry.init_task(self.on_init)
|
||||
return self
|
||||
|
||||
async def on_init(self, registry: Registry):
|
||||
connector = registry._conn
|
||||
if connector is None:
|
||||
raise ValueError("Cannot initialise without connector!")
|
||||
connection = await connector.get_connection()
|
||||
if connection is None:
|
||||
raise ValueError("Cannot Init without connection.")
|
||||
info = await EnumInfo.fetch(connection, self.name)
|
||||
if info is None:
|
||||
raise ValueError(f"Enum {self.name} not found in database.")
|
||||
register_enum(info, connection, self.enum, mapping=list(self.mapping.items()))
|
||||
45
src/data/base.py
Normal file
45
src/data/base.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
from itertools import chain
|
||||
from psycopg import sql
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Expression(Protocol):
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RawExpr(Expression):
|
||||
__slots__ = ('expr', 'values')
|
||||
|
||||
expr: sql.Composable
|
||||
values: tuple[Any, ...]
|
||||
|
||||
def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
|
||||
self.expr = expr
|
||||
self.values = values
|
||||
|
||||
def as_tuple(self):
|
||||
return (self.expr, self.values)
|
||||
|
||||
@classmethod
|
||||
def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
|
||||
"""
|
||||
Join a sequence of Expressions into a single RawExpr.
|
||||
"""
|
||||
tups = (
|
||||
expression.as_tuple()
|
||||
for expression in expressions
|
||||
)
|
||||
return cls.join_tuples(*tups, joiner=joiner)
|
||||
|
||||
@classmethod
|
||||
def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
|
||||
exprs, values = zip(*tuples)
|
||||
expr = joiner.join(exprs)
|
||||
value = tuple(chain(*values))
|
||||
return cls(expr, value)
|
||||
155
src/data/columns.py
Normal file
155
src/data/columns.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
|
||||
from psycopg import sql
|
||||
from datetime import datetime
|
||||
|
||||
from .base import RawExpr, Expression
|
||||
from .conditions import Condition, Joiner
|
||||
from .table import Table
|
||||
|
||||
|
||||
class ColumnExpr(RawExpr):
|
||||
__slots__ = ()
|
||||
|
||||
def __lt__(self, obj) -> Condition:
|
||||
expr, values = self.as_tuple()
|
||||
|
||||
if isinstance(obj, Expression):
|
||||
# column < Expression
|
||||
obj_expr, obj_values = obj.as_tuple()
|
||||
cond_exprs = (expr, Joiner.LT, obj_expr)
|
||||
cond_values = (*values, *obj_values)
|
||||
else:
|
||||
# column < Literal
|
||||
cond_exprs = (expr, Joiner.LT, sql.Placeholder())
|
||||
cond_values = (*values, obj)
|
||||
|
||||
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
|
||||
|
||||
def __le__(self, obj) -> Condition:
|
||||
expr, values = self.as_tuple()
|
||||
|
||||
if isinstance(obj, Expression):
|
||||
# column <= Expression
|
||||
obj_expr, obj_values = obj.as_tuple()
|
||||
cond_exprs = (expr, Joiner.LE, obj_expr)
|
||||
cond_values = (*values, *obj_values)
|
||||
else:
|
||||
# column <= Literal
|
||||
cond_exprs = (expr, Joiner.LE, sql.Placeholder())
|
||||
cond_values = (*values, obj)
|
||||
|
||||
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
|
||||
|
||||
def __eq__(self, obj) -> Condition: # type: ignore[override]
|
||||
return Condition._expression_equality(self, obj)
|
||||
|
||||
def __ne__(self, obj) -> Condition: # type: ignore[override]
|
||||
return ~(self.__eq__(obj))
|
||||
|
||||
def __gt__(self, obj) -> Condition:
|
||||
return ~(self.__le__(obj))
|
||||
|
||||
def __ge__(self, obj) -> Condition:
|
||||
return ~(self.__lt__(obj))
|
||||
|
||||
def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
|
||||
if isinstance(obj, Expression):
|
||||
obj_expr, obj_values = obj.as_tuple()
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} + {})").format(self.expr, obj_expr),
|
||||
(*self.values, *obj_values)
|
||||
)
|
||||
else:
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} + {})").format(self.expr, sql.Placeholder()),
|
||||
(*self.values, obj)
|
||||
)
|
||||
|
||||
def __sub__(self, obj) -> 'ColumnExpr':
|
||||
if isinstance(obj, Expression):
|
||||
obj_expr, obj_values = obj.as_tuple()
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} - {})").format(self.expr, obj_expr),
|
||||
(*self.values, *obj_values)
|
||||
)
|
||||
else:
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} - {})").format(self.expr, sql.Placeholder()),
|
||||
(*self.values, obj)
|
||||
)
|
||||
|
||||
def __mul__(self, obj) -> 'ColumnExpr':
|
||||
if isinstance(obj, Expression):
|
||||
obj_expr, obj_values = obj.as_tuple()
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} * {})").format(self.expr, obj_expr),
|
||||
(*self.values, *obj_values)
|
||||
)
|
||||
else:
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} * {})").format(self.expr, sql.Placeholder()),
|
||||
(*self.values, obj)
|
||||
)
|
||||
|
||||
def CAST(self, target_type: sql.Composable):
|
||||
return ColumnExpr(
|
||||
sql.SQL("({}::{})").format(self.expr, target_type),
|
||||
self.values
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import RowModel
|
||||
|
||||
|
||||
class Column(ColumnExpr, Generic[T]):
|
||||
def __init__(self, name: Optional[str] = None,
|
||||
primary: bool = False, references: Optional['Column'] = None,
|
||||
type: Optional[Type[T]] = None):
|
||||
self.primary = primary
|
||||
self.references = references
|
||||
self.name: str = name # type: ignore
|
||||
self.owner: Optional['RowModel'] = None
|
||||
self._type = type
|
||||
|
||||
self.expr = sql.Identifier(name) if name else sql.SQL('')
|
||||
self.values = ()
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
# Only allow setting the owner once
|
||||
self.name = self.name or name
|
||||
self.owner = owner
|
||||
self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)
|
||||
|
||||
@overload
|
||||
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
|
||||
...
|
||||
|
||||
@overload
|
||||
def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T:
|
||||
...
|
||||
|
||||
def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]":
|
||||
# Get value from row data or session
|
||||
if obj is None:
|
||||
return self
|
||||
else:
|
||||
return obj.data[self.name]
|
||||
|
||||
|
||||
class Integer(Column[int]):
|
||||
pass
|
||||
|
||||
|
||||
class String(Column[str]):
|
||||
pass
|
||||
|
||||
|
||||
class Bool(Column[bool]):
|
||||
pass
|
||||
|
||||
|
||||
class Timestamp(Column[datetime]):
|
||||
pass
|
||||
214
src/data/conditions.py
Normal file
214
src/data/conditions.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# from meta import sharding
|
||||
from typing import Any, Union
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from psycopg import sql
|
||||
|
||||
from .base import Expression, RawExpr
|
||||
|
||||
|
||||
"""
|
||||
A Condition is a "logical" database expression, intended for use in Where statements.
|
||||
Conditions support bitwise logical operators ~, &, |, each producing another Condition.
|
||||
"""
|
||||
|
||||
NULL = None
|
||||
|
||||
|
||||
class Joiner(Enum):
|
||||
EQUALS = ('=', '!=')
|
||||
IS = ('IS', 'IS NOT')
|
||||
LIKE = ('LIKE', 'NOT LIKE')
|
||||
BETWEEN = ('BETWEEN', 'NOT BETWEEN')
|
||||
IN = ('IN', 'NOT IN')
|
||||
LT = ('<', '>=')
|
||||
LE = ('<=', '>')
|
||||
NONE = ('', '')
|
||||
|
||||
|
||||
class Condition(Expression):
|
||||
__slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values')
|
||||
|
||||
def __init__(self,
|
||||
expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''),
|
||||
values: tuple[Any, ...] = (), negated=False
|
||||
):
|
||||
self.expr1 = expr1
|
||||
self.joiner = joiner
|
||||
self.negated = negated
|
||||
self.expr2 = expr2
|
||||
self.values = values
|
||||
|
||||
def as_tuple(self):
|
||||
expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2))
|
||||
if self.negated and self.joiner is Joiner.NONE:
|
||||
expr = sql.SQL("NOT ({})").format(expr)
|
||||
return (expr, self.values)
|
||||
|
||||
@classmethod
|
||||
def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]):
|
||||
"""
|
||||
Construct a Condition from a sequence of Conditions,
|
||||
together with some explicit column conditions.
|
||||
"""
|
||||
# TODO: Consider adding a _table identifier here so we can identify implicit columns
|
||||
# Or just require subquery type conditions to always come from modelled tables.
|
||||
implicit_conditions = (
|
||||
cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
|
||||
)
|
||||
return cls._and(*conditions, *implicit_conditions)
|
||||
|
||||
@classmethod
|
||||
def _and(cls, *conditions: 'Condition'):
|
||||
if not len(conditions):
|
||||
raise ValueError("Cannot combine 0 Conditions")
|
||||
if len(conditions) == 1:
|
||||
return conditions[0]
|
||||
|
||||
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
|
||||
cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
|
||||
cond_values = tuple(chain(*values))
|
||||
|
||||
return Condition(cond_expr, values=cond_values)
|
||||
|
||||
@classmethod
|
||||
def _or(cls, *conditions: 'Condition'):
|
||||
if not len(conditions):
|
||||
raise ValueError("Cannot combine 0 Conditions")
|
||||
if len(conditions) == 1:
|
||||
return conditions[0]
|
||||
|
||||
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
|
||||
cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
|
||||
cond_values = tuple(chain(*values))
|
||||
|
||||
return Condition(cond_expr, values=cond_values)
|
||||
|
||||
@classmethod
|
||||
def _not(cls, condition: 'Condition'):
|
||||
condition.negated = not condition.negated
|
||||
return condition
|
||||
|
||||
@classmethod
|
||||
def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
|
||||
# TODO: Check if this supports sbqueries
|
||||
col_expr, col_values = column.as_tuple()
|
||||
|
||||
# TODO: Also support sql.SQL? For joins?
|
||||
if isinstance(value, Expression):
|
||||
# column = Expression
|
||||
value_expr, value_values = value.as_tuple()
|
||||
cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
|
||||
cond_values = (*col_values, *value_values)
|
||||
elif isinstance(value, (tuple, list)):
|
||||
# column in (...)
|
||||
# TODO: Support expressions in value tuple?
|
||||
if not value:
|
||||
raise ValueError("Cannot create Condition from empty iterable!")
|
||||
value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
|
||||
cond_exprs = (col_expr, Joiner.IN, value_expr)
|
||||
cond_values = (*col_values, *value)
|
||||
elif value is None:
|
||||
# column IS NULL
|
||||
cond_exprs = (col_expr, Joiner.IS, sql.NULL)
|
||||
cond_values = col_values
|
||||
else:
|
||||
# column = Literal
|
||||
cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
|
||||
cond_values = (*col_values, value)
|
||||
|
||||
return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
|
||||
|
||||
def __invert__(self) -> 'Condition':
|
||||
self.negated = not self.negated
|
||||
return self
|
||||
|
||||
def __and__(self, condition: 'Condition') -> 'Condition':
|
||||
return self._and(self, condition)
|
||||
|
||||
def __or__(self, condition: 'Condition') -> 'Condition':
|
||||
return self._or(self, condition)
|
||||
|
||||
|
||||
# Helper method to simply condition construction
|
||||
def condition(*args, **kwargs) -> Condition:
|
||||
return Condition.construct(*args, **kwargs)
|
||||
|
||||
|
||||
# class NOT(Condition):
|
||||
# __slots__ = ('value',)
|
||||
#
|
||||
# def __init__(self, value):
|
||||
# self.value = value
|
||||
#
|
||||
# def apply(self, key, values, conditions):
|
||||
# item = self.value
|
||||
# if isinstance(item, (list, tuple)):
|
||||
# if item:
|
||||
# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
|
||||
# values.extend(item)
|
||||
# else:
|
||||
# raise ValueError("Cannot check an empty iterable!")
|
||||
# else:
|
||||
# conditions.append("{}!={}".format(key, _replace_char))
|
||||
# values.append(item)
|
||||
#
|
||||
#
|
||||
# class GEQ(Condition):
|
||||
# __slots__ = ('value',)
|
||||
#
|
||||
# def __init__(self, value):
|
||||
# self.value = value
|
||||
#
|
||||
# def apply(self, key, values, conditions):
|
||||
# item = self.value
|
||||
# if isinstance(item, (list, tuple)):
|
||||
# raise ValueError("Cannot apply GEQ condition to a list!")
|
||||
# else:
|
||||
# conditions.append("{} >= {}".format(key, _replace_char))
|
||||
# values.append(item)
|
||||
#
|
||||
#
|
||||
# class LEQ(Condition):
|
||||
# __slots__ = ('value',)
|
||||
#
|
||||
# def __init__(self, value):
|
||||
# self.value = value
|
||||
#
|
||||
# def apply(self, key, values, conditions):
|
||||
# item = self.value
|
||||
# if isinstance(item, (list, tuple)):
|
||||
# raise ValueError("Cannot apply LEQ condition to a list!")
|
||||
# else:
|
||||
# conditions.append("{} <= {}".format(key, _replace_char))
|
||||
# values.append(item)
|
||||
#
|
||||
#
|
||||
# class Constant(Condition):
|
||||
# __slots__ = ('value',)
|
||||
#
|
||||
# def __init__(self, value):
|
||||
# self.value = value
|
||||
#
|
||||
# def apply(self, key, values, conditions):
|
||||
# conditions.append("{} {}".format(key, self.value))
|
||||
#
|
||||
#
|
||||
# class SHARDID(Condition):
|
||||
# __slots__ = ('shardid', 'shard_count')
|
||||
#
|
||||
# def __init__(self, shardid, shard_count):
|
||||
# self.shardid = shardid
|
||||
# self.shard_count = shard_count
|
||||
#
|
||||
# def apply(self, key, values, conditions):
|
||||
# if self.shard_count > 1:
|
||||
# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
|
||||
# values.append(self.shardid)
|
||||
#
|
||||
#
|
||||
# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
|
||||
#
|
||||
#
|
||||
# NULL = Constant('IS NULL')
|
||||
# NOTNULL = Constant('IS NOT NULL')
|
||||
58
src/data/connector.py
Normal file
58
src/data/connector.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from typing import Protocol, runtime_checkable, Callable, Awaitable
|
||||
import logging
|
||||
|
||||
import psycopg as psq
|
||||
|
||||
from .cursor import AsyncLoggingCursor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
row_factory = psq.rows.dict_row
|
||||
|
||||
|
||||
class Connector:
|
||||
cursor_factory = AsyncLoggingCursor
|
||||
|
||||
def __init__(self, conn_args):
|
||||
self._conn_args = conn_args
|
||||
self.conn: psq.AsyncConnection = None
|
||||
|
||||
self.conn_hooks = []
|
||||
|
||||
async def get_connection(self) -> psq.AsyncConnection:
|
||||
"""
|
||||
Get the current active connection.
|
||||
This should never be cached outside of a transaction.
|
||||
"""
|
||||
# TODO: Reconnection logic?
|
||||
if not self.conn:
|
||||
raise ValueError("Attempting to get connection before initialisation!")
|
||||
return self.conn
|
||||
|
||||
async def connect(self) -> psq.AsyncConnection:
|
||||
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
|
||||
self.conn = await psq.AsyncConnection.connect(
|
||||
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
|
||||
)
|
||||
for hook in self.conn_hooks:
|
||||
await hook(self.conn)
|
||||
return self.conn
|
||||
|
||||
async def reconnect(self) -> psq.AsyncConnection:
|
||||
return await self.connect()
|
||||
|
||||
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
|
||||
"""
|
||||
Minimal decorator to register a coroutine to run on connect or reconnect.
|
||||
|
||||
Note that these are only run on connect and reconnect.
|
||||
If a hook is registered after connection, it will not be run.
|
||||
"""
|
||||
self.conn_hooks.append(coro)
|
||||
return coro
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Connectable(Protocol):
|
||||
def bind(self, connector: Connector):
|
||||
raise NotImplementedError
|
||||
42
src/data/cursor.py
Normal file
42
src/data/cursor.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from psycopg import AsyncCursor, sql
|
||||
from psycopg.abc import Query, Params
|
||||
from psycopg._encodings import pgconn_encoding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncLoggingCursor(AsyncCursor):
|
||||
def mogrify_query(self, query: Query):
|
||||
if isinstance(query, str):
|
||||
msg = query
|
||||
elif isinstance(query, (sql.SQL, sql.Composed)):
|
||||
msg = query.as_string(self)
|
||||
elif isinstance(query, bytes):
|
||||
msg = query.decode(pgconn_encoding(self._conn.pgconn), 'replace')
|
||||
else:
|
||||
msg = repr(query)
|
||||
return msg
|
||||
|
||||
async def execute(self, query: Query, params: Optional[Params] = None, **kwargs):
|
||||
if logging.DEBUG >= logger.getEffectiveLevel():
|
||||
msg = self.mogrify_query(query)
|
||||
logger.debug(
|
||||
"Executing query (%s) with values %s", msg, params,
|
||||
extra={'action': "Query Execute"}
|
||||
)
|
||||
try:
|
||||
return await super().execute(query, params=params, **kwargs)
|
||||
except Exception:
|
||||
msg = self.mogrify_query(query)
|
||||
logger.exception(
|
||||
"Exception during query execution. Query (%s) with parameters %s.",
|
||||
msg, params,
|
||||
extra={'action': "Query Execute"},
|
||||
stack_info=True
|
||||
)
|
||||
else:
|
||||
# TODO: Possibly log execution time
|
||||
pass
|
||||
46
src/data/database.py
Normal file
46
src/data/database.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import TypeVar
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
# from .cursor import AsyncLoggingCursor
|
||||
from .registry import Registry
|
||||
from .connector import Connector
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Version = namedtuple('Version', ('version', 'time', 'author'))
|
||||
|
||||
T = TypeVar('T', bound=Registry)
|
||||
|
||||
|
||||
class Database(Connector):
|
||||
# cursor_factory = AsyncLoggingCursor
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.registries: dict[str, Registry] = {}
|
||||
|
||||
def load_registry(self, registry: T) -> T:
|
||||
logger.debug(
|
||||
f"Loading and binding registry '{registry.name}'.",
|
||||
extra={'action': f"Reg {registry.name}"}
|
||||
)
|
||||
registry.bind(self)
|
||||
self.registries[registry.name] = registry
|
||||
return registry
|
||||
|
||||
async def version(self) -> Version:
|
||||
"""
|
||||
Return the current schema version as a Version namedtuple.
|
||||
"""
|
||||
async with self.conn.cursor() as cursor:
|
||||
# Get last entry in version table, compare against desired version
|
||||
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return Version(row['version'], row['time'], row['author'])
|
||||
else:
|
||||
# No versions in the database
|
||||
return Version(-1, None, None)
|
||||
320
src/data/models.py
Normal file
320
src/data/models.py
Normal file
@@ -0,0 +1,320 @@
|
||||
from typing import TypeVar, Type, Optional, Generic, Union
|
||||
# from typing_extensions import Self
|
||||
from weakref import WeakValueDictionary
|
||||
from collections.abc import MutableMapping
|
||||
|
||||
from psycopg.rows import DictRow
|
||||
|
||||
from .table import Table
|
||||
from .columns import Column
|
||||
from . import queries as q
|
||||
from .connector import Connector
|
||||
from .registry import Registry
|
||||
|
||||
|
||||
RowT = TypeVar('RowT', bound='RowModel')
|
||||
|
||||
|
||||
class MISSING:
|
||||
__slots__ = ('oid',)
|
||||
|
||||
def __init__(self, oid):
|
||||
self.oid = oid
|
||||
|
||||
|
||||
class RowTable(Table, Generic[RowT]):
|
||||
__slots__ = (
|
||||
'model',
|
||||
)
|
||||
|
||||
def __init__(self, name, model: Type[RowT], **kwargs):
|
||||
super().__init__(name, **kwargs)
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self.model._columns_
|
||||
|
||||
@property
|
||||
def id_col(self):
|
||||
return self.model._key_
|
||||
|
||||
@property
|
||||
def row_cache(self):
|
||||
return self.model._cache_
|
||||
|
||||
def _many_query_adapter(self, *data):
|
||||
self.model._make_rows(*data)
|
||||
return data
|
||||
|
||||
def _single_query_adapter(self, *data):
|
||||
self.model._make_rows(*data)
|
||||
return data[0]
|
||||
|
||||
def _delete_query_adapter(self, *data):
|
||||
self.model._delete_rows(*data)
|
||||
return data
|
||||
|
||||
# New methods to fetch and create rows
|
||||
async def create_row(self, *args, **kwargs) -> RowT:
|
||||
data = await super().insert(*args, **kwargs)
|
||||
return self.model._make_rows(data)[0]
|
||||
|
||||
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
|
||||
# TODO: Handle list of rowids here?
|
||||
return q.Select(
|
||||
self.identifier,
|
||||
row_adapter=self.model._make_rows,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
|
||||
WK = TypeVar('WK')
|
||||
WV = TypeVar('WV')
|
||||
|
||||
|
||||
class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
|
||||
def __init__(self, ref_cache):
|
||||
self.ref_cache = ref_cache
|
||||
self.weak_cache = WeakValueDictionary()
|
||||
|
||||
def __getitem__(self, key):
|
||||
value = self.weak_cache[key]
|
||||
self.ref_cache[key] = value
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.weak_cache[key] = value
|
||||
self.ref_cache[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.weak_cache[key]
|
||||
try:
|
||||
del self.ref_cache[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.weak_cache
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.weak_cache)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.weak_cache)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def pop(self, key, default=None):
|
||||
if key in self:
|
||||
value = self[key]
|
||||
del self[key]
|
||||
else:
|
||||
value = default
|
||||
return value
|
||||
|
||||
|
||||
# TODO: Implement getitem and setitem, for dynamic column access
|
||||
class RowModel:
|
||||
__slots__ = ('data',)
|
||||
|
||||
_schema_: str = 'public'
|
||||
_tablename_: Optional[str] = None
|
||||
_columns_: dict[str, Column] = {}
|
||||
|
||||
# Cache to keep track of registered Rows
|
||||
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
|
||||
|
||||
_key_: tuple[str, ...] = ()
|
||||
_connector: Optional[Connector] = None
|
||||
_registry: Optional[Registry] = None
|
||||
|
||||
# TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
|
||||
table: RowTable
|
||||
|
||||
def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
|
||||
"""
|
||||
Set table, _columns_, and _key_.
|
||||
"""
|
||||
if table is not None:
|
||||
cls._tablename_ = table
|
||||
|
||||
if cls._tablename_ is not None:
|
||||
columns = {}
|
||||
for key, value in cls.__dict__.items():
|
||||
if isinstance(value, Column):
|
||||
columns[key] = value
|
||||
|
||||
cls._columns_ = columns
|
||||
if not cls._key_:
|
||||
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
|
||||
cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
|
||||
if cls._cache_ is None:
|
||||
cls._cache_ = WeakValueDictionary()
|
||||
|
||||
def __new__(cls, data):
|
||||
# Registry pattern.
|
||||
# Ensure each rowid always refers to a single Model instance
|
||||
if data is not None:
|
||||
rowid = cls._id_from_data(data)
|
||||
|
||||
cache = cls._cache_
|
||||
|
||||
if (row := cache.get(rowid, None)) is not None:
|
||||
obj = row
|
||||
else:
|
||||
obj = cache[rowid] = super().__new__(cls)
|
||||
else:
|
||||
obj = super().__new__(cls)
|
||||
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def as_tuple(cls):
|
||||
return (cls.table.identifier, ())
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.data[key] = value
|
||||
|
||||
@classmethod
|
||||
def bind(cls, connector: Connector):
|
||||
if cls.table is None:
|
||||
raise ValueError("Cannot bind abstract RowModel")
|
||||
cls._connector = connector
|
||||
cls.table.bind(connector)
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def attach_to(cls, registry: Registry):
|
||||
cls._registry = registry
|
||||
return cls
|
||||
|
||||
@property
|
||||
def _dict_(self):
|
||||
return {key: self.data[key] for key in self._key_}
|
||||
|
||||
@property
|
||||
def _rowid_(self):
|
||||
return tuple(self.data[key] for key in self._key_)
|
||||
|
||||
def __repr__(self):
|
||||
return "{}.{}({})".format(
|
||||
self.table.schema,
|
||||
self.table.name,
|
||||
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _id_from_data(cls, data):
|
||||
return tuple(data[key] for key in cls._key_)
|
||||
|
||||
@classmethod
|
||||
def _dict_from_id(cls, rowid):
|
||||
return dict(zip(cls._key_, rowid))
|
||||
|
||||
@classmethod
|
||||
def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
|
||||
"""
|
||||
Create or retrieve Row objects for each provided data row.
|
||||
If the rows already exist in cache, updates the cached row.
|
||||
"""
|
||||
# TODO: Handle partial row data here somehow?
|
||||
rows = [cls(data_row) for data_row in data_rows]
|
||||
return rows
|
||||
|
||||
@classmethod
|
||||
def _delete_rows(cls, *data_rows):
|
||||
"""
|
||||
Remove the given rows from cache, if they exist.
|
||||
May be extended to handle object deletion.
|
||||
"""
|
||||
cache = cls._cache_
|
||||
|
||||
for data_row in data_rows:
|
||||
rowid = cls._id_from_data(data_row)
|
||||
cache.pop(rowid, None)
|
||||
|
||||
@classmethod
|
||||
async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
|
||||
return await cls.table.create_row(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def fetch_where(cls: Type[RowT], *args, **kwargs):
|
||||
return cls.table.fetch_rows_where(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def fetch(cls: Type[RowT], *rowid) -> Optional[RowT]:
|
||||
"""
|
||||
Fetch the row with the given id, retrieving from cache where possible.
|
||||
"""
|
||||
row = cls._cache_.get(rowid, None)
|
||||
if row is None:
|
||||
rows = await cls.fetch_where(**cls._dict_from_id(rowid))
|
||||
row = rows[0] if rows else None
|
||||
if row is None:
|
||||
cls._cache_[rowid] = cls(None)
|
||||
elif row.data is None:
|
||||
row = None
|
||||
|
||||
return row
|
||||
|
||||
@classmethod
|
||||
async def fetch_or_create(cls, *rowid, **kwargs):
|
||||
"""
|
||||
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
|
||||
"""
|
||||
if rowid:
|
||||
row = await cls.fetch(*rowid)
|
||||
else:
|
||||
rows = await cls.fetch_where(**kwargs).limit(1)
|
||||
row = rows[0] if rows else None
|
||||
|
||||
if row is None:
|
||||
creation_kwargs = kwargs
|
||||
if rowid:
|
||||
creation_kwargs.update(cls._dict_from_id(rowid))
|
||||
row = await cls.create(**creation_kwargs)
|
||||
return row
|
||||
|
||||
async def refresh(self: RowT) -> Optional[RowT]:
|
||||
"""
|
||||
Refresh this Row from data.
|
||||
|
||||
The return value may be `None` if the row was deleted.
|
||||
"""
|
||||
rows = await self.table.select_where(**self._dict_)
|
||||
if not rows:
|
||||
return None
|
||||
else:
|
||||
self.data = rows[0]
|
||||
return self
|
||||
|
||||
async def update(self: RowT, **values) -> Optional[RowT]:
|
||||
"""
|
||||
Update this Row with the given values.
|
||||
|
||||
Internally passes the provided `values` to the `update` Query.
|
||||
The return value may be `None` if the row was deleted.
|
||||
"""
|
||||
data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
|
||||
if not data:
|
||||
return None
|
||||
else:
|
||||
return data[0]
|
||||
|
||||
async def delete(self: RowT) -> Optional[RowT]:
|
||||
"""
|
||||
Delete this Row.
|
||||
"""
|
||||
data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
|
||||
return data[0] if data is not None else None
|
||||
592
src/data/queries.py
Normal file
592
src/data/queries.py
Normal file
@@ -0,0 +1,592 @@
|
||||
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from psycopg import AsyncConnection, AsyncCursor
|
||||
from psycopg import sql
|
||||
from psycopg.rows import DictRow
|
||||
|
||||
import logging
|
||||
|
||||
from .conditions import Condition
|
||||
from .base import Expression, RawExpr
|
||||
from .connector import Connector
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
TQueryT = TypeVar('TQueryT', bound='TableQuery')
|
||||
SQueryT = TypeVar('SQueryT', bound='Select')
|
||||
|
||||
QueryResult = TypeVar('QueryResult')
|
||||
|
||||
|
||||
class Query(Generic[QueryResult]):
|
||||
"""
|
||||
ABC for an executable query statement.
|
||||
"""
|
||||
__slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')
|
||||
|
||||
_adapter: Callable[..., QueryResult]
|
||||
|
||||
def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
|
||||
self.connector: Optional[Connector] = connector
|
||||
self.conn: Optional[AsyncConnection] = conn
|
||||
self.cursor: Optional[AsyncCursor] = cursor
|
||||
|
||||
if row_adapter is not None:
|
||||
self._adapter = row_adapter
|
||||
else:
|
||||
self._adapter = self._no_adapter
|
||||
|
||||
self.result: Optional[QueryResult] = None
|
||||
|
||||
def bind(self, connector: Connector):
|
||||
self.connector = connector
|
||||
return self
|
||||
|
||||
def with_cursor(self, cursor: AsyncCursor):
|
||||
self.cursor = cursor
|
||||
return self
|
||||
|
||||
def with_connection(self, conn: AsyncConnection):
|
||||
self.conn = conn
|
||||
return self
|
||||
|
||||
def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
|
||||
return data
|
||||
|
||||
def with_adapter(self, callable: Callable[..., QueryResult]):
|
||||
# NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
|
||||
# For this to work cleanly, callable should have arg type of QR1, not any
|
||||
self._adapter = callable
|
||||
return self
|
||||
|
||||
def with_no_adapter(self):
|
||||
"""
|
||||
Sets the adapater to the identity.
|
||||
"""
|
||||
self._adapter = self._no_adapter
|
||||
return self
|
||||
|
||||
def one(self):
|
||||
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
|
||||
return self
|
||||
|
||||
def build(self) -> Expression:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _execute(self, cursor: AsyncCursor) -> QueryResult:
|
||||
query, values = self.build().as_tuple()
|
||||
# TODO: Move logging out to a custom cursor
|
||||
# logger.debug(
|
||||
# f"Executing query ({query.as_string(cursor)}) with values {values}",
|
||||
# extra={'action': "Query"}
|
||||
# )
|
||||
await cursor.execute(sql.Composed((query,)), values)
|
||||
data = await cursor.fetchall()
|
||||
self.result = self._adapter(*data)
|
||||
return self.result
|
||||
|
||||
async def execute(self, cursor=None) -> QueryResult:
|
||||
"""
|
||||
Execute the query, optionally with the provided cursor, and return the result rows.
|
||||
If no cursor is provided, and no cursor has been set with `with_cursor`,
|
||||
the execution will create a new cursor from the connection and close it automatically.
|
||||
"""
|
||||
# Create a cursor if possible
|
||||
cursor = cursor if cursor is not None else self.cursor
|
||||
if self.cursor is None:
|
||||
if self.conn is None:
|
||||
if self.connector is None:
|
||||
raise ValueError("Cannot execute query without cursor, connection, or connector.")
|
||||
else:
|
||||
conn = await self.connector.get_connection()
|
||||
else:
|
||||
conn = self.conn
|
||||
|
||||
async with conn.cursor() as cursor:
|
||||
data = await self._execute(cursor)
|
||||
else:
|
||||
data = await self._execute(cursor)
|
||||
return data
|
||||
|
||||
def __await__(self):
|
||||
return self.execute().__await__()
|
||||
|
||||
|
||||
class TableQuery(Query[QueryResult]):
|
||||
"""
|
||||
ABC for an executable query statement expected to be run on a single table.
|
||||
"""
|
||||
__slots__ = (
|
||||
'tableid',
|
||||
'condition', '_extra', '_limit', '_order', '_joins'
|
||||
)
|
||||
|
||||
def __init__(self, tableid, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tableid: sql.Identifier = tableid
|
||||
|
||||
def options(self, **kwargs):
|
||||
"""
|
||||
Set some query options.
|
||||
Default implementation does nothing.
|
||||
Should be overridden to provide specific options.
|
||||
"""
|
||||
return self
|
||||
|
||||
|
||||
class WhereMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.condition: Optional[Condition] = None
|
||||
|
||||
def where(self, *args: Condition, **kwargs):
|
||||
"""
|
||||
Add a Condition to the query.
|
||||
Position arguments should be Conditions,
|
||||
and keyword arguments should be of the form `column=Value`,
|
||||
where Value may be a Value-type or a literal value.
|
||||
All provided Conditions will be and-ed together to create a new Condition.
|
||||
TODO: Maybe just pass this verbatim to a condition.
|
||||
"""
|
||||
if args or kwargs:
|
||||
condition = Condition.construct(*args, **kwargs)
|
||||
if self.condition is not None:
|
||||
condition = self.condition & condition
|
||||
|
||||
self.condition = condition
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def _where_section(self) -> Optional[Expression]:
|
||||
if self.condition is not None:
|
||||
return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class JOINTYPE(Enum):
|
||||
LEFT = sql.SQL('LEFT JOIN')
|
||||
RIGHT = sql.SQL('RIGHT JOIN')
|
||||
INNER = sql.SQL('INNER JOIN')
|
||||
OUTER = sql.SQL('OUTER JOIN')
|
||||
FULLOUTER = sql.SQL('FULL OUTER JOIN')
|
||||
|
||||
|
||||
class JoinMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
# TODO: Remember to add join slots to TableQuery
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._joins: list[Expression] = []
|
||||
|
||||
def join(self,
|
||||
target: Union[str, Expression],
|
||||
on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
|
||||
join_type: JOINTYPE = JOINTYPE.INNER,
|
||||
natural=False):
|
||||
available = (on is not None) + (using is not None) + natural
|
||||
if available == 0:
|
||||
raise ValueError("No conditions given for Query Join")
|
||||
if available > 1:
|
||||
raise ValueError("Exactly one join format must be given for Query Join")
|
||||
|
||||
sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
|
||||
if isinstance(target, str):
|
||||
sections.append((sql.Identifier(target), ()))
|
||||
else:
|
||||
sections.append(target.as_tuple())
|
||||
|
||||
if on is not None:
|
||||
sections.append((sql.SQL('ON'), ()))
|
||||
sections.append(on.as_tuple())
|
||||
elif using is not None:
|
||||
sections.append((sql.SQL('USING'), ()))
|
||||
if isinstance(using, Expression):
|
||||
sections.append(using.as_tuple())
|
||||
elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str):
|
||||
cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using))
|
||||
sections.append((cols, ()))
|
||||
else:
|
||||
raise ValueError("Unrecognised 'using' type.")
|
||||
elif natural:
|
||||
sections.insert(0, (sql.SQL('NATURAL'), ()))
|
||||
|
||||
expr = RawExpr.join_tuples(*sections)
|
||||
self._joins.append(expr)
|
||||
return self
|
||||
|
||||
def leftjoin(self, *args, **kwargs):
|
||||
return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)
|
||||
|
||||
@property
|
||||
def _join_section(self) -> Optional[Expression]:
|
||||
if self._joins:
|
||||
return RawExpr.join(*self._joins)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ExtraMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._extra: Optional[Expression] = None
|
||||
|
||||
def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
|
||||
"""
|
||||
Add an extra string, and optionally values, to this query.
|
||||
The extra string is inserted after any condition, and before the limit.
|
||||
"""
|
||||
extra_expr = RawExpr(extra, values)
|
||||
if self._extra is not None:
|
||||
extra_expr = RawExpr.join(self._extra, extra_expr)
|
||||
self._extra = extra_expr
|
||||
return self
|
||||
|
||||
@property
|
||||
def _extra_section(self) -> Optional[Expression]:
|
||||
if self._extra is None:
|
||||
return None
|
||||
else:
|
||||
return self._extra
|
||||
|
||||
|
||||
class LimitMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._limit: Optional[int] = None
|
||||
|
||||
def limit(self, limit: int):
|
||||
"""
|
||||
Add a limit to this query.
|
||||
"""
|
||||
self._limit = limit
|
||||
return self
|
||||
|
||||
@property
|
||||
def _limit_section(self) -> Optional[Expression]:
|
||||
if self._limit is not None:
|
||||
return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ORDER(Enum):
|
||||
ASC = sql.SQL('ASC')
|
||||
DESC = sql.SQL('DESC')
|
||||
|
||||
|
||||
class NULLS(Enum):
|
||||
FIRST = sql.SQL('NULLS FIRST')
|
||||
LAST = sql.SQL('NULLS LAST')
|
||||
|
||||
|
||||
class OrderMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._order: list[Expression] = []
|
||||
|
||||
def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
|
||||
"""
|
||||
Add a single sort expression to the query.
|
||||
This method stacks.
|
||||
"""
|
||||
if isinstance(expr, Expression):
|
||||
string, values = expr.as_tuple()
|
||||
else:
|
||||
string = sql.Identifier(expr)
|
||||
values = ()
|
||||
|
||||
parts = [string]
|
||||
if direction is not None:
|
||||
parts.append(direction.value)
|
||||
if nulls is not None:
|
||||
parts.append(nulls.value)
|
||||
|
||||
order_string = sql.SQL(' ').join(parts)
|
||||
self._order.append(RawExpr(order_string, values))
|
||||
return self
|
||||
|
||||
@property
|
||||
def _order_section(self) -> Optional[Expression]:
|
||||
if self._order:
|
||||
expr = RawExpr.join(*self._order, joiner=sql.SQL(', '))
|
||||
expr.expr = sql.SQL("ORDER BY {}").format(expr.expr)
|
||||
return expr
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class Insert(ExtraMixin, TableQuery[QueryResult]):
|
||||
"""
|
||||
Query type representing a table insert query.
|
||||
"""
|
||||
# TODO: Support ON CONFLICT for upserts
|
||||
__slots__ = ('_columns', '_values', '_conflict')
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._columns: tuple[str, ...] = ()
|
||||
self._values: tuple[tuple[Any, ...], ...] = ()
|
||||
self._conflict: Optional[Expression] = None
|
||||
|
||||
def insert(self, columns, *values):
|
||||
"""
|
||||
Insert the given data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: tuple[str]
|
||||
Tuple of column names to insert.
|
||||
|
||||
values: tuple[tuple[Any, ...], ...]
|
||||
Tuple of values to insert, corresponding to the columns.
|
||||
"""
|
||||
if not values:
|
||||
raise ValueError("Cannot insert zero rows.")
|
||||
if len(values[0]) != len(columns):
|
||||
raise ValueError("Number of columns does not match length of values.")
|
||||
|
||||
self._columns = columns
|
||||
self._values = values
|
||||
return self
|
||||
|
||||
def on_conflict(self, ignore=False):
|
||||
# TODO lots more we can do here
|
||||
# Maybe return a Conflict object that can chain itself (not the query)
|
||||
if ignore:
|
||||
self._conflict = RawExpr(sql.SQL('DO NOTHING'))
|
||||
return self
|
||||
|
||||
@property
|
||||
def _conflict_section(self) -> Optional[Expression]:
|
||||
if self._conflict is not None:
|
||||
e, v = self._conflict.as_tuple()
|
||||
expr = RawExpr(
|
||||
sql.SQL("ON CONFLICT {}").format(
|
||||
e
|
||||
),
|
||||
v
|
||||
)
|
||||
return expr
|
||||
return None
|
||||
|
||||
def build(self):
|
||||
columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
|
||||
single_value_str = sql.SQL('({})').format(
|
||||
sql.SQL(',').join(sql.Placeholder() * len(self._columns))
|
||||
)
|
||||
values_str = sql.SQL(',').join(single_value_str * len(self._values))
|
||||
|
||||
# TODO: Check efficiency of inserting multiple values like this
|
||||
# Also implement a Copy query
|
||||
base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
|
||||
table=self.tableid,
|
||||
columns=columns,
|
||||
values_str=values_str
|
||||
)
|
||||
|
||||
sections = [
|
||||
RawExpr(base, tuple(chain(*self._values))),
|
||||
self._conflict_section,
|
||||
self._extra_section,
|
||||
RawExpr(sql.SQL('RETURNING *'))
|
||||
]
|
||||
|
||||
sections = (section for section in sections if section is not None)
|
||||
return RawExpr.join(*sections)
|
||||
|
||||
|
||||
class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, TableQuery[QueryResult]):
|
||||
"""
|
||||
Select rows from a table matching provided conditions.
|
||||
"""
|
||||
__slots__ = ('_columns',)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._columns: tuple[Expression, ...] = ()
|
||||
|
||||
def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
|
||||
"""
|
||||
Set the columns and expressions to select.
|
||||
If none are given, selects all columns.
|
||||
"""
|
||||
cols: List[Expression] = []
|
||||
if columns:
|
||||
cols.extend(map(RawExpr, map(sql.Identifier, columns)))
|
||||
if exprs:
|
||||
for name, expr in exprs.items():
|
||||
if isinstance(expr, str):
|
||||
cols.append(
|
||||
RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name))
|
||||
)
|
||||
elif isinstance(expr, sql.Composable):
|
||||
cols.append(
|
||||
RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name))
|
||||
)
|
||||
elif isinstance(expr, Expression):
|
||||
value_expr, value_values = expr.as_tuple()
|
||||
cols.append(RawExpr(
|
||||
value_expr + sql.SQL(' AS ') + sql.Identifier(name),
|
||||
value_values
|
||||
))
|
||||
if cols:
|
||||
self._columns = (*self._columns, *cols)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
if not self._columns:
|
||||
columns, columns_values = sql.SQL('*'), ()
|
||||
else:
|
||||
columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()
|
||||
|
||||
base = sql.SQL("SELECT {columns} FROM {table}").format(
|
||||
columns=columns,
|
||||
table=self.tableid
|
||||
)
|
||||
|
||||
sections = [
|
||||
RawExpr(base, columns_values),
|
||||
self._join_section,
|
||||
self._where_section,
|
||||
self._extra_section,
|
||||
self._order_section,
|
||||
self._limit_section,
|
||||
]
|
||||
|
||||
sections = (section for section in sections if section is not None)
|
||||
return RawExpr.join(*sections)
|
||||
|
||||
|
||||
class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
|
||||
"""
|
||||
Query type representing a table delete query.
|
||||
"""
|
||||
# TODO: Cascade option for delete, maybe other options
|
||||
# TODO: Require a where unless specifically disabled, for safety
|
||||
|
||||
def build(self):
|
||||
base = sql.SQL("DELETE FROM {table}").format(
|
||||
table=self.tableid,
|
||||
)
|
||||
sections = [
|
||||
RawExpr(base),
|
||||
self._where_section,
|
||||
self._extra_section,
|
||||
RawExpr(sql.SQL('RETURNING *'))
|
||||
]
|
||||
|
||||
sections = (section for section in sections if section is not None)
|
||||
return RawExpr.join(*sections)
|
||||
|
||||
|
||||
class Update(LimitMixin, WhereMixin, ExtraMixin, TableQuery[QueryResult]):
|
||||
__slots__ = (
|
||||
'_set',
|
||||
)
|
||||
# TODO: Again, require a where unless specifically disabled
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._set: List[Expression] = []
|
||||
|
||||
def set(self, **column_values: Union[Any, Expression]):
|
||||
exprs: List[Expression] = []
|
||||
for name, value in column_values.items():
|
||||
if isinstance(value, Expression):
|
||||
value_tup = value.as_tuple()
|
||||
else:
|
||||
value_tup = (sql.Placeholder(), (value,))
|
||||
|
||||
exprs.append(
|
||||
RawExpr.join_tuples(
|
||||
(sql.Identifier(name), ()),
|
||||
value_tup,
|
||||
joiner=sql.SQL(' = ')
|
||||
)
|
||||
)
|
||||
self._set.extend(exprs)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
if not self._set:
|
||||
raise ValueError("No columns provided to update.")
|
||||
set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
|
||||
|
||||
base = sql.SQL("UPDATE {table} SET {set}").format(
|
||||
table=self.tableid,
|
||||
set=set_expr
|
||||
)
|
||||
sections = [
|
||||
RawExpr(base, set_values),
|
||||
self._where_section,
|
||||
self._extra_section,
|
||||
self._limit_section,
|
||||
RawExpr(sql.SQL('RETURNING *'))
|
||||
]
|
||||
|
||||
sections = (section for section in sections if section is not None)
|
||||
return RawExpr.join(*sections)
|
||||
|
||||
|
||||
# async def upsert(cursor, table, constraint, **values):
|
||||
# """
|
||||
# Insert or on conflict update.
|
||||
# """
|
||||
# valuedict = values
|
||||
# keys, values = zip(*values.items())
|
||||
#
|
||||
# key_str = _format_insertkeys(keys)
|
||||
# value_str, values = _format_insertvalues(values)
|
||||
# update_key_str, update_key_values = _format_updatestr(valuedict)
|
||||
#
|
||||
# if not isinstance(constraint, str):
|
||||
# constraint = ", ".join(constraint)
|
||||
#
|
||||
# await cursor.execute(
|
||||
# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
|
||||
# table, key_str, value_str, constraint, update_key_str
|
||||
# ),
|
||||
# tuple((*values, *update_key_values))
|
||||
# )
|
||||
# return await cursor.fetchone()
|
||||
|
||||
|
||||
# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
|
||||
# cursor = cursor or conn.cursor()
|
||||
#
|
||||
# # TODO: executemany or copy syntax now
|
||||
# return execute_values(
|
||||
# cursor,
|
||||
# """
|
||||
# UPDATE {table}
|
||||
# SET {set_clause}
|
||||
# FROM (VALUES {cast_row}%s)
|
||||
# AS {temp_table}
|
||||
# WHERE {where_clause}
|
||||
# RETURNING *
|
||||
# """.format(
|
||||
# table=table,
|
||||
# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
|
||||
# cast_row=cast_row + ',' if cast_row else '',
|
||||
# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
|
||||
# temp_table="_t ({})".format(', '.join(set_keys + where_keys))
|
||||
# ),
|
||||
# values,
|
||||
# fetch=True
|
||||
# )
|
||||
102
src/data/registry.py
Normal file
102
src/data/registry.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from typing import Protocol, runtime_checkable, Optional
|
||||
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .connector import Connector, Connectable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class _Attachable(Connectable, Protocol):
|
||||
def attach_to(self, registry: 'Registry'):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Registry:
|
||||
_attached: list[_Attachable] = []
|
||||
_name: Optional[str] = None
|
||||
|
||||
def __init_subclass__(cls, name=None):
|
||||
attached = []
|
||||
for _, member in cls.__dict__.items():
|
||||
if isinstance(member, _Attachable):
|
||||
attached.append(member)
|
||||
cls._attached = attached
|
||||
cls._name = name or cls.__name__
|
||||
|
||||
def __init__(self, name=None):
|
||||
self._conn: Optional[Connector] = None
|
||||
self.name: str = name if name is not None else self._name
|
||||
if self.name is None:
|
||||
raise ValueError("A Registry must have a name!")
|
||||
|
||||
self.init_tasks = []
|
||||
|
||||
for member in self._attached:
|
||||
member.attach_to(self)
|
||||
|
||||
def bind(self, connector: Connector):
|
||||
self._conn = connector
|
||||
for child in self._attached:
|
||||
child.bind(connector)
|
||||
|
||||
def attach(self, attachable):
|
||||
self._attached.append(attachable)
|
||||
if self._conn is not None:
|
||||
attachable.bind(self._conn)
|
||||
return attachable
|
||||
|
||||
def init_task(self, coro):
|
||||
"""
|
||||
Initialisation tasks are run to setup the registry state.
|
||||
These tasks will be run in the event loop, after connection to the database.
|
||||
These tasks should be idempotent, as they may be run on reload and reconnect.
|
||||
"""
|
||||
self.init_tasks.append(coro)
|
||||
return coro
|
||||
|
||||
async def init(self):
|
||||
for task in self.init_tasks:
|
||||
await task(self)
|
||||
return self
|
||||
|
||||
|
||||
class AttachableClass:
|
||||
"""ABC for a default implementation of an Attachable class."""
|
||||
|
||||
_connector: Optional[Connector] = None
|
||||
_registry: Optional[Registry] = None
|
||||
|
||||
@classmethod
|
||||
def bind(cls, connector: Connector):
|
||||
cls._connector = connector
|
||||
connector.connect_hook(cls.on_connect)
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def attach_to(cls, registry: Registry):
|
||||
cls._registry = registry
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
async def on_connect(cls, connection: AsyncConnection):
|
||||
pass
|
||||
|
||||
|
||||
class Attachable:
|
||||
"""ABC for a default implementation of an Attachable object."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._connector: Optional[Connector] = None
|
||||
self._registry: Optional[Registry] = None
|
||||
|
||||
def bind(self, connector: Connector):
|
||||
self._connector = connector
|
||||
connector.connect_hook(self.on_connect)
|
||||
return self
|
||||
|
||||
def attach_to(self, registry: Registry):
|
||||
self._registry = registry
|
||||
return self
|
||||
|
||||
async def on_connect(self, connection: AsyncConnection):
|
||||
pass
|
||||
95
src/data/table.py
Normal file
95
src/data/table.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from typing import Optional
|
||||
from psycopg.rows import DictRow
|
||||
from psycopg import sql
|
||||
|
||||
from . import queries as q
|
||||
from .connector import Connector
|
||||
from .registry import Registry
|
||||
|
||||
|
||||
class Table:
|
||||
"""
|
||||
Transparent interface to a single table structure in the database.
|
||||
Contains standard methods to access the table.
|
||||
"""
|
||||
|
||||
def __init__(self, name, *args, schema='public', **kwargs):
|
||||
self.name: str = name
|
||||
self.schema: str = schema
|
||||
self.connector: Connector = None
|
||||
|
||||
@property
|
||||
def identifier(self):
|
||||
if self.schema == 'public':
|
||||
return sql.Identifier(self.name)
|
||||
else:
|
||||
return sql.Identifier(self.schema, self.name)
|
||||
|
||||
def bind(self, connector: Connector):
|
||||
self.connector = connector
|
||||
return self
|
||||
|
||||
def attach_to(self, registry: Registry):
|
||||
self._registry = registry
|
||||
return self
|
||||
|
||||
def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
|
||||
return data
|
||||
|
||||
def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]:
|
||||
if data:
|
||||
return data[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
|
||||
return data
|
||||
|
||||
def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
|
||||
return q.Select(
|
||||
self.identifier,
|
||||
row_adapter=self._many_query_adapter,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
|
||||
return q.Select(
|
||||
self.identifier,
|
||||
row_adapter=self._single_query_adapter,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
|
||||
return q.Update(
|
||||
self.identifier,
|
||||
row_adapter=self._many_query_adapter,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
|
||||
return q.Delete(
|
||||
self.identifier,
|
||||
row_adapter=self._many_query_adapter,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
def insert(self, **column_values) -> q.Insert[DictRow]:
|
||||
return q.Insert(
|
||||
self.identifier,
|
||||
row_adapter=self._single_query_adapter,
|
||||
connector=self.connector
|
||||
).insert(column_values.keys(), column_values.values())
|
||||
|
||||
def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
|
||||
return q.Insert(
|
||||
self.identifier,
|
||||
row_adapter=self._many_query_adapter,
|
||||
connector=self.connector
|
||||
).insert(*args, **kwargs)
|
||||
|
||||
# def update_many(self, *args, **kwargs):
|
||||
# with self.conn:
|
||||
# return update_many(self.identifier, *args, **kwargs)
|
||||
|
||||
# def upsert(self, *args, **kwargs):
|
||||
# return upsert(self.identifier, *args, **kwargs)
|
||||
1
src/gui
Submodule
1
src/gui
Submodule
Submodule src/gui added at 76659f1193
200
src/meta/LionBot.py
Normal file
200
src/meta/LionBot.py
Normal file
@@ -0,0 +1,200 @@
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from discord.utils import MISSING
|
||||
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
|
||||
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
|
||||
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from data import Database
|
||||
|
||||
from .config import Conf
|
||||
from .logger import logging_context, log_context, log_action_stack
|
||||
from .context import context
|
||||
from .LionContext import LionContext
|
||||
from .LionTree import LionTree
|
||||
from .errors import HandledException, SafeCancellation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core import CoreCog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LionBot(Bot):
|
||||
def __init__(
|
||||
self, *args, appname: str, shardname: str, db: Database, config: Conf,
|
||||
initial_extensions: List[str], web_client: ClientSession, app_ipc,
|
||||
testing_guilds: List[int] = [], translator=None, **kwargs
|
||||
):
|
||||
kwargs.setdefault('tree_cls', LionTree)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.web_client = web_client
|
||||
self.testing_guilds = testing_guilds
|
||||
self.initial_extensions = initial_extensions
|
||||
self.db = db
|
||||
self.appname = appname
|
||||
self.shardname = shardname
|
||||
# self.appdata = appdata
|
||||
self.config = config
|
||||
self.app_ipc = app_ipc
|
||||
self.core: Optional['CoreCog'] = None
|
||||
self.translator = translator
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
log_context.set(f"APP: {self.application_id}")
|
||||
await self.app_ipc.connect()
|
||||
|
||||
if self.translator is not None:
|
||||
await self.tree.set_translator(self.translator)
|
||||
|
||||
for extension in self.initial_extensions:
|
||||
await self.load_extension(extension)
|
||||
|
||||
for guildid in self.testing_guilds:
|
||||
guild = discord.Object(guildid)
|
||||
self.tree.copy_global_to(guild=guild)
|
||||
await self.tree.sync(guild=guild)
|
||||
|
||||
async def add_cog(self, cog: Cog, **kwargs):
|
||||
with logging_context(action=f"Attach {cog.__cog_name__}"):
|
||||
logger.info(f"Attaching Cog {cog.__cog_name__}")
|
||||
await super().add_cog(cog, **kwargs)
|
||||
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
|
||||
|
||||
async def load_extension(self, name, *, package=None, **kwargs):
|
||||
with logging_context(action=f"Load {name.strip('.')}"):
|
||||
logger.info(f"Loading extension {name} in package {package}.")
|
||||
await super().load_extension(name, package=package, **kwargs)
|
||||
logger.debug(f"Loaded extension {name} in package {package}.")
|
||||
|
||||
async def start(self, token: str, *, reconnect: bool = True):
|
||||
with logging_context(action="Login"):
|
||||
await self.login(token)
|
||||
with logging_context(stack=["Running"]):
|
||||
await self.connect(reconnect=reconnect)
|
||||
|
||||
async def on_ready(self):
|
||||
logger.info(
|
||||
f"Logged in as {self.application.name}\n"
|
||||
f"Application id {self.application.id}\n"
|
||||
f"Shard Talk identifier {self.shardname}\n"
|
||||
"------------------------------\n"
|
||||
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
|
||||
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
|
||||
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
|
||||
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
|
||||
"------------------------------\n"
|
||||
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
|
||||
"Ready to take commands!\n",
|
||||
extra={'action': 'Ready'}
|
||||
)
|
||||
|
||||
async def get_context(self, origin, /, *, cls=MISSING):
|
||||
if cls is MISSING:
|
||||
cls = LionContext
|
||||
ctx = await super().get_context(origin, cls=cls)
|
||||
context.set(ctx)
|
||||
return ctx
|
||||
|
||||
async def on_command(self, ctx: LionContext):
|
||||
logger.info(
|
||||
f"Executing command '{ctx.command.qualified_name}' "
|
||||
f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') "
|
||||
f"with arguments {ctx.args} and kwargs {ctx.kwargs}.",
|
||||
extra={'with_ctx': True}
|
||||
)
|
||||
|
||||
async def on_command_error(self, ctx, exception):
|
||||
# TODO: Some of these could have more user-feedback
|
||||
cmd_str = str(ctx.command)
|
||||
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
|
||||
cmd_str = ctx.command.app_command.to_dict()
|
||||
try:
|
||||
raise exception
|
||||
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
|
||||
try:
|
||||
if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
|
||||
original = exception.original.original
|
||||
raise original
|
||||
else:
|
||||
original = exception.original
|
||||
raise original
|
||||
except HandledException:
|
||||
pass
|
||||
except SafeCancellation:
|
||||
if original.msg:
|
||||
try:
|
||||
await ctx.error_reply(original.msg)
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug(
|
||||
f"Caught a safe cancellation: {original.details}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except discord.Forbidden:
|
||||
# Unknown uncaught Forbidden
|
||||
try:
|
||||
# Attempt a general error reply
|
||||
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
|
||||
except Exception:
|
||||
# We can't send anything at all. Exit quietly, but log.
|
||||
logger.warning(
|
||||
f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
|
||||
exc_info=True,
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except discord.HTTPException:
|
||||
logger.warning(
|
||||
f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
|
||||
exc_info=True,
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
|
||||
error_embed = discord.Embed(title="Something went wrong!")
|
||||
error_embed.description = (
|
||||
"An unexpected error occurred while processing your command!\n"
|
||||
"Our development team has been notified, and the issue should be fixed soon.\n"
|
||||
"If the error persists, please contact our support team and give them the following number: "
|
||||
f"`{ctx.interaction.id}`"
|
||||
)
|
||||
|
||||
try:
|
||||
await ctx.error_reply(embed=error_embed)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
exception.original = HandledException(exception.original)
|
||||
except CheckFailure:
|
||||
logger.debug(
|
||||
f"Command failed check: {exception}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
try:
|
||||
await ctx.error_rely(exception.message)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
# Completely unknown exception outside of command invocation!
|
||||
# Something is very wrong here, don't attempt user interaction.
|
||||
logger.exception(
|
||||
f"Caught an unknown top-level exception while executing: {cmd_str}",
|
||||
extra={'action': 'BotError', 'with_ctx': True}
|
||||
)
|
||||
|
||||
def add_command(self, command):
|
||||
if hasattr(command, '_placeholder_group_'):
|
||||
return
|
||||
super().add_command(command)
|
||||
58
src/meta/LionCog.py
Normal file
58
src/meta/LionCog.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from typing import Any
|
||||
|
||||
from discord.ext.commands import Cog
|
||||
from discord.ext import commands as cmds
|
||||
|
||||
|
||||
class LionCog(Cog):
|
||||
# A set of other cogs that this cog depends on
|
||||
depends_on: set['LionCog'] = set()
|
||||
_placeholder_groups_: set[str]
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
cls._placeholder_groups_ = set()
|
||||
|
||||
for base in reversed(cls.__mro__):
|
||||
for elem, value in base.__dict__.items():
|
||||
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
|
||||
cls._placeholder_groups_.add(value.name)
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any):
|
||||
# Patch to ensure no placeholder groups are in the command list
|
||||
self = super().__new__(cls)
|
||||
self.__cog_commands__ = [
|
||||
command for command in self.__cog_commands__ if command.name not in cls._placeholder_groups_
|
||||
]
|
||||
return self
|
||||
|
||||
async def _inject(self, bot, *args, **kwargs):
|
||||
if self.depends_on:
|
||||
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
|
||||
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
|
||||
|
||||
return await super()._inject(bot, *args, *kwargs)
|
||||
|
||||
@classmethod
|
||||
def placeholder_group(cls, group: cmds.HybridGroup):
|
||||
group._placeholder_group_ = True
|
||||
return group
|
||||
|
||||
def crossload_group(self, placeholder_group: cmds.HybridGroup, target_group: cmds.HybridGroup):
|
||||
"""
|
||||
Crossload a placeholder group's commands into the target group
|
||||
"""
|
||||
if not isinstance(placeholder_group, cmds.HybridGroup) or not isinstance(target_group, cmds.HybridGroup):
|
||||
raise ValueError("Placeholder and target groups my be HypridGroups.")
|
||||
if placeholder_group.name not in self._placeholder_groups_:
|
||||
raise ValueError("Placeholder group was not registered! Stopping to avoid duplicates.")
|
||||
if target_group.app_command is None:
|
||||
raise ValueError("Target group has no app_command to crossload into.")
|
||||
|
||||
for command in placeholder_group.commands:
|
||||
placeholder_group.remove_command(command.name)
|
||||
target_group.remove_command(command.name)
|
||||
acmd = command.app_command._copy_with(parent=target_group.app_command, binding=self)
|
||||
command.app_command = acmd
|
||||
target_group.add_command(command)
|
||||
190
src/meta/LionContext.py
Normal file
190
src/meta/LionContext.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import types
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import discord
|
||||
from discord.ext.commands import Context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .LionBot import LionBot
|
||||
from core.lion import Lion
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
"""
|
||||
Stuff that might be useful to implement (see cmdClient):
|
||||
sent_messages cache
|
||||
tasks cache
|
||||
error reply
|
||||
usage
|
||||
interaction cache
|
||||
View cache?
|
||||
setting access
|
||||
"""
|
||||
|
||||
|
||||
FlatContext = namedtuple(
|
||||
'FlatContext',
|
||||
('message',
|
||||
'interaction',
|
||||
'guild',
|
||||
'author',
|
||||
'alias',
|
||||
'prefix',
|
||||
'failed')
|
||||
)
|
||||
|
||||
|
||||
class LionContext(Context['LionBot']):
|
||||
"""
|
||||
Represents the context a command is invoked under.
|
||||
|
||||
Extends Context to add Lion-specific methods and attributes.
|
||||
Also adds several contextual wrapped utilities for simpler user during command invocation.
|
||||
"""
|
||||
alion: 'Lion'
|
||||
|
||||
def __repr__(self):
|
||||
parts = {}
|
||||
if self.interaction is not None:
|
||||
parts['iid'] = self.interaction.id
|
||||
parts['itype'] = f"\"{self.interaction.type.name}\""
|
||||
if self.message is not None:
|
||||
parts['mid'] = self.message.id
|
||||
if self.author is not None:
|
||||
parts['uid'] = self.author.id
|
||||
parts['uname'] = f"\"{self.author.name}\""
|
||||
if self.channel is not None:
|
||||
parts['cid'] = self.channel.id
|
||||
parts['cname'] = f"\"{self.channel.name}\""
|
||||
if self.guild is not None:
|
||||
parts['gid'] = self.guild.id
|
||||
parts['gname'] = f"\"{self.guild.name}\""
|
||||
if self.command is not None:
|
||||
parts['cmd'] = f"\"{self.command.qualified_name}\""
|
||||
if self.invoked_with is not None:
|
||||
parts['alias'] = f"\"{self.invoked_with}\""
|
||||
if self.command_failed:
|
||||
parts['failed'] = self.command_failed
|
||||
|
||||
return "<LionContext: {}>".format(
|
||||
' '.join(f"{name}={value}" for name, value in parts.items())
|
||||
)
|
||||
|
||||
def flatten(self):
|
||||
"""Flat pure-data context information, for caching and logging."""
|
||||
return FlatContext(
|
||||
self.message.id,
|
||||
self.interaction.id if self.interaction is not None else None,
|
||||
self.guild.id if self.guild is not None else None,
|
||||
self.author.id if self.author is not None else None,
|
||||
self.channel.id if self.channel is not None else None,
|
||||
self.invoked_with,
|
||||
self.prefix,
|
||||
self.command_failed
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def util(cls, util_func):
|
||||
"""
|
||||
Decorator to make a utility function available as a Context instance method.
|
||||
"""
|
||||
setattr(cls, util_func.__name__, util_func)
|
||||
logger.debug(f"Attached context utility function: {util_func.__name__}")
|
||||
return util_func
|
||||
|
||||
@classmethod
|
||||
def wrappable_util(cls, util_func):
|
||||
"""
|
||||
Decorator to add a Wrappable utility function as a Context instance method.
|
||||
"""
|
||||
wrapped = Wrappable(util_func)
|
||||
setattr(cls, util_func.__name__, wrapped)
|
||||
logger.debug(f"Attached wrappable context utility function: {util_func.__name__}")
|
||||
return wrapped
|
||||
|
||||
async def error_reply(self, content: Optional[str] = None, **kwargs):
|
||||
if content and 'embed' not in kwargs:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.red(),
|
||||
description=content
|
||||
)
|
||||
kwargs['embed'] = embed
|
||||
content = None
|
||||
|
||||
# Expect this may be run in highly unusual circumstances.
|
||||
# This should never error, or at least handle all errors.
|
||||
try:
|
||||
await self.reply(content=content, **kwargs)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Unknown exception in 'error_reply'.",
|
||||
extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True}
|
||||
)
|
||||
|
||||
|
||||
class Wrappable:
|
||||
__slots__ = ('_func', 'wrappers')
|
||||
|
||||
def __init__(self, func):
|
||||
self._func = func
|
||||
self.wrappers = None
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
return self._func.__name__
|
||||
|
||||
def add_wrapper(self, func, name=None):
|
||||
self.wrappers = self.wrappers or {}
|
||||
name = name or func.__name__
|
||||
self.wrappers[name] = func
|
||||
logger.debug(
|
||||
f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.",
|
||||
extra={'action': "Wrap Util"}
|
||||
)
|
||||
|
||||
def remove_wrapper(self, name):
|
||||
if not self.wrappers or name not in self.wrappers:
|
||||
raise ValueError(
|
||||
f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
|
||||
)
|
||||
self.wrappers.pop(name)
|
||||
logger.debug(
|
||||
f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
|
||||
extra={'action': "Unwrap Util"}
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.wrappers:
|
||||
return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)
|
||||
else:
|
||||
return self._func(*args, **kwargs)
|
||||
|
||||
def _wrapped(self, iter_wraps):
|
||||
next_wrap = next(iter_wraps, None)
|
||||
if next_wrap:
|
||||
def _func(*args, **kwargs):
|
||||
return next_wrap(self._wrapped(iter_wraps), *args, **kwargs)
|
||||
else:
|
||||
_func = self._func
|
||||
return _func
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is None:
|
||||
return self
|
||||
else:
|
||||
return types.MethodType(self, instance)
|
||||
|
||||
|
||||
LionContext.reply = Wrappable(LionContext.reply)
|
||||
|
||||
|
||||
# @LionContext.reply.add_wrapper
|
||||
# async def think(func, ctx, *args, **kwargs):
|
||||
# await ctx.channel.send("thinking")
|
||||
# await func(ctx, *args, **kwargs)
|
||||
79
src/meta/LionTree.py
Normal file
79
src/meta/LionTree.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import logging
|
||||
|
||||
from discord import Interaction
|
||||
from discord.app_commands import CommandTree
|
||||
from discord.app_commands.errors import AppCommandError, CommandInvokeError
|
||||
from discord.enums import InteractionType
|
||||
from discord.app_commands.namespace import Namespace
|
||||
|
||||
from .logger import logging_context
|
||||
from .errors import SafeCancellation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LionTree(CommandTree):
|
||||
async def on_error(self, interaction, error) -> None:
|
||||
try:
|
||||
if isinstance(error, CommandInvokeError):
|
||||
raise error.original
|
||||
else:
|
||||
raise error
|
||||
except SafeCancellation:
|
||||
# Assume this has already been handled
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
|
||||
|
||||
async def _call(self, interaction):
|
||||
with logging_context(context=f"iid: {interaction.id}"):
|
||||
if not await self.interaction_check(interaction):
|
||||
interaction.command_failed = True
|
||||
return
|
||||
|
||||
data = interaction.data # type: ignore
|
||||
type = data.get('type', 1)
|
||||
if type != 1:
|
||||
# Context menu command...
|
||||
await self._call_context_menu(interaction, data, type)
|
||||
return
|
||||
|
||||
command, options = self._get_app_command_options(data)
|
||||
|
||||
# Pre-fill the cached slot to prevent re-computation
|
||||
interaction._cs_command = command
|
||||
|
||||
# At this point options refers to the arguments of the command
|
||||
# and command refers to the class type we care about
|
||||
namespace = Namespace(interaction, data.get('resolved', {}), options)
|
||||
|
||||
# Same pre-fill as above
|
||||
interaction._cs_namespace = namespace
|
||||
|
||||
# Auto complete handles the namespace differently... so at this point this is where we decide where that is.
|
||||
if interaction.type is InteractionType.autocomplete:
|
||||
with logging_context(action=f"Acmp {command.qualified_name}"):
|
||||
focused = next((opt['name'] for opt in options if opt.get('focused')), None)
|
||||
if focused is None:
|
||||
raise AppCommandError(
|
||||
'This should not happen, but there is no focused element. This is a Discord bug.'
|
||||
)
|
||||
await command._invoke_autocomplete(interaction, focused, namespace)
|
||||
return
|
||||
|
||||
with logging_context(action=f"Run {command.qualified_name}"):
|
||||
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
|
||||
try:
|
||||
await command._invoke_with_namespace(interaction, namespace)
|
||||
except AppCommandError as e:
|
||||
interaction.command_failed = True
|
||||
await command._invoke_error_handlers(interaction, e)
|
||||
await self.on_error(interaction, e)
|
||||
else:
|
||||
if not interaction.command_failed:
|
||||
self.client.dispatch('app_command_completion', interaction, command)
|
||||
finally:
|
||||
if interaction.command_failed:
|
||||
logger.debug("Command completed with errors.")
|
||||
else:
|
||||
logger.debug("Command completed without errors.")
|
||||
15
src/meta/__init__.py
Normal file
15
src/meta/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .LionBot import LionBot
|
||||
from .LionCog import LionCog
|
||||
from .LionContext import LionContext
|
||||
from .LionTree import LionTree
|
||||
|
||||
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
|
||||
from .config import conf, configEmoji
|
||||
from .args import args
|
||||
from .app import appname, shard_talk, appname_from_shard, shard_from_appname
|
||||
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
|
||||
from .context import context, ctx_bot
|
||||
|
||||
from . import sharding
|
||||
from . import logger
|
||||
from . import app
|
||||
46
src/meta/app.py
Normal file
46
src/meta/app.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
appname: str
|
||||
The base identifer for this application.
|
||||
This identifies which services the app offers.
|
||||
shardname: str
|
||||
The specific name of the running application.
|
||||
Only one process should be connecteded with a given appname.
|
||||
For the bot apps, usually specifies the shard id and shard number.
|
||||
"""
|
||||
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
|
||||
|
||||
from . import sharding, conf
|
||||
from .logger import log_app
|
||||
from .ipc.client import AppClient
|
||||
from .args import args
|
||||
|
||||
|
||||
appname = conf.data['appid']
|
||||
appid = appname # backwards compatibility
|
||||
|
||||
|
||||
def appname_from_shard(shardid):
|
||||
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
|
||||
return appname
|
||||
|
||||
|
||||
def shard_from_appname(appname: str):
|
||||
return int(appname.rsplit('_', maxsplit=1)[-1])
|
||||
|
||||
|
||||
shardname = appname_from_shard(sharding.shard_number)
|
||||
|
||||
log_app.set(shardname)
|
||||
|
||||
|
||||
shard_talk = AppClient(
|
||||
shardname,
|
||||
appname,
|
||||
{'host': args.host, 'port': args.port},
|
||||
{'host': conf.appipc['server_host'], 'port': int(conf.appipc['server_port'])}
|
||||
)
|
||||
|
||||
|
||||
@shard_talk.register_route()
|
||||
async def ping():
|
||||
return "Pong!"
|
||||
35
src/meta/args.py
Normal file
35
src/meta/args.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import argparse
|
||||
|
||||
from constants import CONFIG_FILE
|
||||
|
||||
# ------------------------------
|
||||
# Parsed commandline arguments
|
||||
# ------------------------------
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--conf',
|
||||
dest='config',
|
||||
default=CONFIG_FILE,
|
||||
help="Path to configuration file."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--shard',
|
||||
dest='shard',
|
||||
default=None,
|
||||
type=int,
|
||||
help="Shard number to run, if applicable."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--host',
|
||||
dest='host',
|
||||
default='127.0.0.1',
|
||||
help="IP address to run the app listener on."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--port',
|
||||
dest='port',
|
||||
default='5001',
|
||||
help="Port to run the app listener on."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
137
src/meta/config.py
Normal file
137
src/meta/config.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from discord import PartialEmoji
|
||||
import configparser as cfgp
|
||||
|
||||
from .args import args
|
||||
|
||||
|
||||
class configEmoji(PartialEmoji):
|
||||
__slots__ = ('fallback',)
|
||||
|
||||
def __init__(self, *args, fallback=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fallback = fallback
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, emojistr: str):
|
||||
"""
|
||||
Parses emoji strings of one of the following forms
|
||||
`<a:name:id> or fallback`
|
||||
`<:name:id> or fallback`
|
||||
`<a:name:id>`
|
||||
`<:name:id>`
|
||||
"""
|
||||
splits = emojistr.rsplit(' or ', maxsplit=1)
|
||||
|
||||
fallback = splits[1] if len(splits) > 1 else None
|
||||
emojistr = splits[0].strip('<> ')
|
||||
animated, name, id = emojistr.split(':')
|
||||
return cls(
|
||||
name=name,
|
||||
fallback=PartialEmoji(name=fallback) if fallback is not None else None,
|
||||
animated=bool(animated),
|
||||
id=int(id) if id else None
|
||||
)
|
||||
|
||||
|
||||
class MapDotProxy:
|
||||
"""
|
||||
Allows dot access to an underlying Mappable object.
|
||||
"""
|
||||
__slots__ = ("_map", "_converter")
|
||||
|
||||
def __init__(self, mappable, converter=None):
|
||||
self._map = mappable
|
||||
self._converter = converter
|
||||
|
||||
def __getattribute__(self, key):
|
||||
_map = object.__getattribute__(self, '_map')
|
||||
if key == '_map':
|
||||
return _map
|
||||
if key in _map:
|
||||
_converter = object.__getattribute__(self, '_converter')
|
||||
if _converter:
|
||||
return _converter(_map[key])
|
||||
else:
|
||||
return _map[key]
|
||||
else:
|
||||
return object.__getattribute__(_map, key)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._map.__getitem__(key)
|
||||
|
||||
|
||||
class ConfigParser(cfgp.ConfigParser):
|
||||
"""
|
||||
Extension of base ConfigParser allowing optional
|
||||
section option retrieval without defaults.
|
||||
"""
|
||||
def options(self, section, no_defaults=False, **kwargs):
|
||||
if no_defaults:
|
||||
try:
|
||||
return list(self._sections[section].keys())
|
||||
except KeyError:
|
||||
raise cfgp.NoSectionError(section)
|
||||
else:
|
||||
return super().options(section, **kwargs)
|
||||
|
||||
|
||||
class Conf:
|
||||
def __init__(self, configfile, section_name="DEFAULT"):
|
||||
self.configfile = configfile
|
||||
|
||||
self.config = ConfigParser(
|
||||
converters={
|
||||
"intlist": self._getintlist,
|
||||
"list": self._getlist,
|
||||
"emoji": configEmoji.from_str,
|
||||
}
|
||||
)
|
||||
self.config.read(configfile)
|
||||
|
||||
self.section_name = section_name if section_name in self.config else 'DEFAULT'
|
||||
|
||||
self.default = self.config["DEFAULT"]
|
||||
self.section = MapDotProxy(self.config[self.section_name])
|
||||
self.bot = self.section
|
||||
|
||||
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
|
||||
more_to_read = self.section.getlist("ALSO_READ", [])
|
||||
read = set()
|
||||
while more_to_read:
|
||||
to_read = more_to_read.pop(0)
|
||||
read.add(to_read)
|
||||
self.config.read(to_read)
|
||||
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
|
||||
if path not in read and path not in more_to_read]
|
||||
more_to_read.extend(new_paths)
|
||||
|
||||
self.emojis = MapDotProxy(
|
||||
self.config['EMOJIS'] if 'EMOJIS' in self.config else self.section,
|
||||
converter=configEmoji.from_str
|
||||
)
|
||||
|
||||
global conf
|
||||
conf = self
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.section[key].strip()
|
||||
|
||||
def __getattr__(self, section):
|
||||
return self.config[section.upper()]
|
||||
|
||||
def get(self, name, fallback=None):
|
||||
result = self.section.get(name, fallback)
|
||||
return result.strip() if result else result
|
||||
|
||||
def _getintlist(self, value):
|
||||
return [int(item.strip()) for item in value.split(',')]
|
||||
|
||||
def _getlist(self, value):
|
||||
return [item.strip() for item in value.split(',')]
|
||||
|
||||
def write(self):
|
||||
with open(self.configfile, 'w') as conffile:
|
||||
self.config.write(conffile)
|
||||
|
||||
|
||||
conf = Conf(args.config, 'STUDYLION')
|
||||
20
src/meta/context.py
Normal file
20
src/meta/context.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Namespace for various global context variables.
|
||||
Allows asyncio callbacks to accurately retrieve information about the current state.
|
||||
"""
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .LionBot import LionBot
|
||||
from .LionContext import LionContext
|
||||
|
||||
|
||||
# Contains the current command context, if applicable
|
||||
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
|
||||
|
||||
# Contains the current LionBot instance
|
||||
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)
|
||||
64
src/meta/errors.py
Normal file
64
src/meta/errors.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Optional
|
||||
from string import Template
|
||||
|
||||
|
||||
class SafeCancellation(Exception):
|
||||
"""
|
||||
Raised to safely cancel execution of the current operation.
|
||||
|
||||
If not caught, is expected to be propagated to the Tree and safely ignored there.
|
||||
If a `msg` is provided, a context-aware error handler should catch and send the message to the user.
|
||||
The error handler should then set the `msg` to None, to avoid double handling.
|
||||
Debugging information should go in `details`, to be logged by a top-level error handler.
|
||||
"""
|
||||
default_message = ""
|
||||
|
||||
@property
|
||||
def msg(self):
|
||||
return self._msg if self._msg is not None else self.default_message
|
||||
|
||||
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
|
||||
self._msg: Optional[str] = _msg
|
||||
self.details: str = details if details is not None else self.msg
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class UserInputError(SafeCancellation):
|
||||
"""
|
||||
A SafeCancellation induced from unparseable user input.
|
||||
"""
|
||||
default_message = "Could not understand your input."
|
||||
|
||||
@property
|
||||
def msg(self):
|
||||
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
|
||||
|
||||
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
|
||||
self.info = info
|
||||
super().__init__(_msg, **kwargs)
|
||||
|
||||
|
||||
class UserCancelled(SafeCancellation):
|
||||
"""
|
||||
A SafeCancellation induced from manual user cancellation.
|
||||
|
||||
Usually silent.
|
||||
"""
|
||||
default_msg = None
|
||||
|
||||
|
||||
class ResponseTimedOut(SafeCancellation):
|
||||
"""
|
||||
A SafeCancellation induced from a user interaction time-out.
|
||||
"""
|
||||
default_msg = "Session timed out waiting for input."
|
||||
|
||||
|
||||
class HandledException(SafeCancellation):
|
||||
"""
|
||||
Sentinel class to indicate to error handlers that this exception has been handled.
|
||||
Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler.
|
||||
"""
|
||||
def __init__(self, exc=None, **kwargs):
|
||||
self.exc = exc
|
||||
super().__init__(**kwargs)
|
||||
2
src/meta/ipc/__init__.py
Normal file
2
src/meta/ipc/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .client import AppClient, AppPayload, AppRoute
|
||||
from .server import AppServer
|
||||
236
src/meta/ipc/client.py
Normal file
236
src/meta/ipc/client.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from typing import Optional, TypeAlias, Any
|
||||
import asyncio
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
from ..logger import logging_context
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Address: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
class AppClient:
|
||||
routes: dict[str, 'AppRoute'] = {} # route_name -> Callable[Any, Awaitable[Any]]
|
||||
|
||||
def __init__(self, appid: str, basename: str, client_address: Address, server_address: Address):
|
||||
self.appid = appid # String identifier for this ShardTalk client
|
||||
self.basename = basename # Prefix used to recognise app peers
|
||||
self.address = client_address
|
||||
self.server_address = server_address
|
||||
|
||||
self.peers = {appid: client_address} # appid -> address
|
||||
|
||||
self._listener: Optional[asyncio.Server] = None # Local client server
|
||||
self._server = None # Connection to the registry server
|
||||
self._keepalive = None
|
||||
|
||||
self.register_route('new_peer')(self.new_peer)
|
||||
self.register_route('drop_peer')(self.drop_peer)
|
||||
self.register_route('peer_list')(self.peer_list)
|
||||
|
||||
@property
|
||||
def my_peers(self):
|
||||
return {peerid: peer for peerid, peer in self.peers.items() if peerid.startswith(self.basename)}
|
||||
|
||||
def register_route(self, name=None):
|
||||
def wrapper(coro):
|
||||
route = AppRoute(coro, client=self, name=name)
|
||||
self.routes[route.name] = route
|
||||
return route
|
||||
return wrapper
|
||||
|
||||
async def server_connection(self):
|
||||
"""Establish a connection to the registry server"""
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(**self.server_address)
|
||||
|
||||
payload = ('connect', (), {'appid': self.appid, 'address': self.address})
|
||||
writer.write(pickle.dumps(payload))
|
||||
writer.write(b'\n')
|
||||
await writer.drain()
|
||||
|
||||
data = await reader.readline()
|
||||
peers = pickle.loads(data)
|
||||
self.peers = peers
|
||||
self._server = (reader, writer)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Could not connect to registry server. Trying again in 30 seconds.",
|
||||
extra={'action': 'Connect'}
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
asyncio.create_task(self.server_connection())
|
||||
else:
|
||||
logger.debug(
|
||||
"Connected to the registry server, launching keepalive.",
|
||||
extra={'action': 'Connect'}
|
||||
)
|
||||
self._keepalive = asyncio.create_task(self._server_keepalive())
|
||||
|
||||
async def _server_keepalive(self):
|
||||
with logging_context(action='Keepalive'):
|
||||
if self._server is None:
|
||||
raise ValueError("Cannot keepalive non-existent server!")
|
||||
reader, write = self._server
|
||||
try:
|
||||
await reader.read()
|
||||
except Exception:
|
||||
logger.exception("Lost connection to address server. Reconnecting...")
|
||||
else:
|
||||
# Connection ended or broke
|
||||
logger.info("Lost connection to address server. Reconnecting...")
|
||||
await asyncio.sleep(30)
|
||||
asyncio.create_task(self.server_connection())
|
||||
|
||||
async def new_peer(self, appid, address):
|
||||
self.peers[appid] = address
|
||||
|
||||
async def peer_list(self, peers):
|
||||
self.peers = peers
|
||||
|
||||
async def drop_peer(self, appid):
|
||||
self.peers.pop(appid, None)
|
||||
|
||||
async def close(self):
|
||||
# Close connection to the server
|
||||
# TODO
|
||||
...
|
||||
|
||||
async def request(self, appid, payload: 'AppPayload', wait_for_reply=True):
|
||||
with logging_context(action=f"Req {appid}"):
|
||||
try:
|
||||
if appid not in self.peers:
|
||||
raise ValueError(f"Peer '{appid}' not found.")
|
||||
logger.debug(f"Sending request to app '{appid}' with payload {payload}")
|
||||
|
||||
address = self.peers[appid]
|
||||
reader, writer = await asyncio.open_connection(**address)
|
||||
|
||||
writer.write(payload.encoded())
|
||||
await writer.drain()
|
||||
writer.write_eof()
|
||||
if wait_for_reply:
|
||||
result = await reader.read()
|
||||
writer.close()
|
||||
decoded = payload.route.decode(result)
|
||||
return decoded
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
logging.exception(f"Failed to send request to {appid}'")
|
||||
return None
|
||||
|
||||
async def requestall(self, payload, except_self=True, only_my_peers=True):
|
||||
with logging_context(action="Broadcast"):
|
||||
peerlist = self.my_peers if only_my_peers else self.peers
|
||||
results = await asyncio.gather(
|
||||
*(self.request(appid, payload) for appid in peerlist if (appid != self.appid or not except_self)),
|
||||
return_exceptions=True
|
||||
)
|
||||
return dict(zip(self.peers.keys(), results))
|
||||
|
||||
async def handle_request(self, reader, writer):
|
||||
with logging_context(action="SERV"):
|
||||
data = await reader.read()
|
||||
loaded = pickle.loads(data)
|
||||
route, args, kwargs = loaded
|
||||
|
||||
with logging_context(action=f"SERV {route}"):
|
||||
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
|
||||
|
||||
if route in self.routes:
|
||||
try:
|
||||
await self.routes[route].run((reader, writer), args, kwargs)
|
||||
except Exception:
|
||||
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
|
||||
else:
|
||||
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
|
||||
writer.write_eof()
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
Start the local peer server.
|
||||
Connect to the address server.
|
||||
"""
|
||||
with logging_context(stack=['ShardTalk']):
|
||||
# Start the client server
|
||||
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
|
||||
|
||||
logger.info(f"Serving on {self.address}")
|
||||
await self.server_connection()
|
||||
|
||||
|
||||
class AppPayload:
|
||||
__slots__ = ('route', 'args', 'kwargs')
|
||||
|
||||
def __init__(self, route, *args, **kwargs):
|
||||
self.route = route
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __await__(self):
|
||||
return self.route.execute(*self.args, **self.kwargs).__await__()
|
||||
|
||||
def encoded(self):
|
||||
return pickle.dumps((self.route.name, self.args, self.kwargs))
|
||||
|
||||
async def send(self, appid, **kwargs):
|
||||
return await self.route._client.request(appid, self, **kwargs)
|
||||
|
||||
async def broadcast(self, **kwargs):
|
||||
return await self.route._client.requestall(self, **kwargs)
|
||||
|
||||
|
||||
class AppRoute:
|
||||
__slots__ = ('func', 'name', '_client')
|
||||
|
||||
def __init__(self, func, client=None, name=None):
|
||||
self.func = func
|
||||
self.name = name or func.__name__
|
||||
self._client = client
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return AppPayload(self, *args, **kwargs)
|
||||
|
||||
def encode(self, output):
|
||||
return pickle.dumps(output)
|
||||
|
||||
def decode(self, encoded):
|
||||
# TODO: Handle exceptions here somehow
|
||||
if len(encoded) > 0:
|
||||
return pickle.loads(encoded)
|
||||
else:
|
||||
return ''
|
||||
|
||||
def encoder(self, func):
|
||||
self.encode = func
|
||||
|
||||
def decoder(self, func):
|
||||
self.decode = func
|
||||
|
||||
async def execute(self, *args, **kwargs):
|
||||
"""
|
||||
Execute the underlying function, with the given arguments.
|
||||
"""
|
||||
return await self.func(*args, **kwargs)
|
||||
|
||||
async def run(self, connection, args, kwargs):
|
||||
"""
|
||||
Run the route, with the given arguments, using the given connection.
|
||||
"""
|
||||
# TODO: ContextVar here for logging? Or in handle_request?
|
||||
# Get encoded result
|
||||
# TODO: handle exceptions in the execution process
|
||||
try:
|
||||
result = await self.execute(*args, **kwargs)
|
||||
payload = self.encode(result)
|
||||
except Exception:
|
||||
logger.exception(f"Exception occured running route '{self.name}' with args: {args} and kwargs: {kwargs}")
|
||||
payload = b''
|
||||
_, writer = connection
|
||||
writer.write(payload)
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
175
src/meta/ipc/server.py
Normal file
175
src/meta/ipc/server.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import asyncio
|
||||
import pickle
|
||||
import logging
|
||||
import string
|
||||
import random
|
||||
|
||||
from ..logger import log_context, log_app, logging_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
uuid_alphabet = string.ascii_lowercase + string.digits
|
||||
|
||||
|
||||
def short_uuid():
|
||||
return ''.join(random.choices(uuid_alphabet, k=10))
|
||||
|
||||
|
||||
class AppServer:
|
||||
routes = {} # route name -> bound method
|
||||
|
||||
def __init__(self):
|
||||
self.clients = {} # AppID -> (info, connection)
|
||||
|
||||
self.route('ping')(self.route_ping)
|
||||
self.route('whereis')(self.route_whereis)
|
||||
self.route('peers')(self.route_peers)
|
||||
self.route('connect')(self.client_connection)
|
||||
|
||||
@classmethod
|
||||
def route(cls, route_name):
|
||||
"""
|
||||
Decorator to add a route to the server.
|
||||
"""
|
||||
def wrapper(coro):
|
||||
cls.routes[route_name] = coro
|
||||
return coro
|
||||
return wrapper
|
||||
|
||||
async def route_ping(self, connection):
|
||||
"""
|
||||
Pong.
|
||||
"""
|
||||
reader, writer = connection
|
||||
writer.write(b"Pong")
|
||||
writer.write_eof()
|
||||
|
||||
async def route_whereis(self, connection, appid):
|
||||
"""
|
||||
Return an address for the given client appid.
|
||||
Returns None if the client does not have a connection.
|
||||
"""
|
||||
reader, writer = connection
|
||||
if appid in self.clients:
|
||||
writer.write(pickle.dumps(self.clients[appid][0]))
|
||||
else:
|
||||
writer.write(b'')
|
||||
writer.write_eof()
|
||||
|
||||
async def route_peers(self, connection):
|
||||
"""
|
||||
Send back a map of current peers.
|
||||
"""
|
||||
reader, writer = connection
|
||||
peers = self.peer_list()
|
||||
payload = pickle.dumps(('peer_list', (peers,)))
|
||||
writer.write(payload)
|
||||
writer.write_eof()
|
||||
|
||||
async def client_connection(self, connection, appid, address):
|
||||
"""
|
||||
Register and hold a new client connection.
|
||||
"""
|
||||
with logging_context(action=f"CONN {appid}"):
|
||||
reader, writer = connection
|
||||
# Add the new client
|
||||
self.clients[appid] = (address, connection)
|
||||
|
||||
# Send the new client a client list
|
||||
peers = self.peer_list()
|
||||
writer.write(pickle.dumps(peers))
|
||||
writer.write(b'\n')
|
||||
await writer.drain()
|
||||
|
||||
# Announce the new client to everyone
|
||||
await self.broadcast('new_peer', (), {'appid': appid, 'address': address})
|
||||
|
||||
# Keep the connection open until socket closed or EOF (indicating client death)
|
||||
try:
|
||||
await reader.read()
|
||||
finally:
|
||||
# Connection ended or it broke
|
||||
logger.info(f"Lost client '{appid}'")
|
||||
await self.deregister_client(appid)
|
||||
|
||||
async def handle_connection(self, reader, writer):
|
||||
data = await reader.readline()
|
||||
route, args, kwargs = pickle.loads(data)
|
||||
|
||||
rqid = short_uuid()
|
||||
|
||||
with logging_context(context=f"RQID: {rqid}", action=f"ROUTE {route}"):
|
||||
logger.info(f"AppServer handling request on route '{route}' with args {args} and kwargs {kwargs}")
|
||||
|
||||
if route in self.routes:
|
||||
# Execute route
|
||||
try:
|
||||
await self.routes[route]((reader, writer), *args, **kwargs)
|
||||
except Exception:
|
||||
logger.exception(f"AppServer recieved exception during route '{route}'")
|
||||
else:
|
||||
logger.warning(f"AppServer recieved unknown route '{route}'. Ignoring.")
|
||||
|
||||
def peer_list(self):
|
||||
return {appid: address for appid, (address, _) in self.clients.items()}
|
||||
|
||||
async def deregister_client(self, appid):
|
||||
self.clients.pop(appid, None)
|
||||
await self.broadcast('drop_peer', (), {'appid': appid})
|
||||
|
||||
async def broadcast(self, route, args, kwargs):
|
||||
with logging_context(action="broadcast"):
|
||||
logger.debug(f"Sending broadcast on route '{route}' with args {args} and kwargs {kwargs}.")
|
||||
payload = pickle.dumps((route, args, kwargs))
|
||||
if self.clients:
|
||||
await asyncio.gather(
|
||||
*(self._send(appid, payload) for appid in self.clients),
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
async def message_client(self, appid, route, args, kwargs):
|
||||
"""
|
||||
Send a message to client `appid` along `route` with given arguments.
|
||||
"""
|
||||
with logging_context(action=f"MSG {appid}"):
|
||||
logger.debug(f"Sending '{route}' to '{appid}' with args {args} and kwargs {kwargs}.")
|
||||
if appid not in self.clients:
|
||||
raise ValueError(f"Client '{appid}' is not connected.")
|
||||
|
||||
payload = pickle.dumps((route, args, kwargs))
|
||||
return await self._send(appid, payload)
|
||||
|
||||
async def _send(self, appid, payload):
|
||||
"""
|
||||
Send the encoded `payload` to the client `appid`.
|
||||
"""
|
||||
address, _ = self.clients[appid]
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(**address)
|
||||
writer.write(payload)
|
||||
writer.write_eof()
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
except Exception as ex:
|
||||
# TODO: Close client if we can't connect?
|
||||
logger.exception(f"Failed to send message to '{appid}'")
|
||||
raise ex
|
||||
|
||||
async def start(self, address):
|
||||
log_app.set("APPSERVER")
|
||||
with logging_context(stack=["SERV"]):
|
||||
server = await asyncio.start_server(self.handle_connection, **address)
|
||||
logger.info(f"Serving on {address}")
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
|
||||
|
||||
async def start_server():
|
||||
address = {'host': '127.0.0.1', 'port': '5000'}
|
||||
server = AppServer()
|
||||
await server.start(address)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(start_server())
|
||||
324
src/meta/logger.py
Normal file
324
src/meta/logger.py
Normal file
@@ -0,0 +1,324 @@
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List
|
||||
from logging.handlers import QueueListener, QueueHandler
|
||||
from queue import SimpleQueue
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from functools import wraps
|
||||
from contextvars import ContextVar
|
||||
|
||||
from discord import Webhook, File
|
||||
import aiohttp
|
||||
|
||||
from .config import conf
|
||||
from . import sharding
|
||||
from .context import context
|
||||
from utils.lib import utc_now
|
||||
|
||||
|
||||
log_logger = logging.getLogger(__name__)
|
||||
log_logger.propagate = False
|
||||
|
||||
|
||||
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
|
||||
log_action_stack: ContextVar[List[str]] = ContextVar('logging_action_stack', default=[])
|
||||
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def logging_context(context=None, action=None, stack=None):
|
||||
if context is not None:
|
||||
context_t = log_context.set(context)
|
||||
if action is not None:
|
||||
astack = log_action_stack.get()
|
||||
log_action_stack.set(astack + [action])
|
||||
if stack is not None:
|
||||
actions_t = log_action_stack.set(stack)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if context is not None:
|
||||
log_context.reset(context_t)
|
||||
if stack is not None:
|
||||
log_action_stack.reset(actions_t)
|
||||
if action is not None:
|
||||
log_action_stack.set(astack)
|
||||
|
||||
|
||||
def log_wrap(**kwargs):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapped(*w_args, **w_kwargs):
|
||||
with logging_context(**kwargs):
|
||||
return await func(*w_args, **w_kwargs)
|
||||
return wrapped
|
||||
return decorator
|
||||
|
||||
|
||||
RESET_SEQ = "\033[0m"
|
||||
COLOR_SEQ = "\033[3%dm"
|
||||
BOLD_SEQ = "\033[1m"
|
||||
"]]]"
|
||||
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
|
||||
|
||||
|
||||
def colour_escape(fmt: str) -> str:
|
||||
cmap = {
|
||||
'%(black)': COLOR_SEQ % BLACK,
|
||||
'%(red)': COLOR_SEQ % RED,
|
||||
'%(green)': COLOR_SEQ % GREEN,
|
||||
'%(yellow)': COLOR_SEQ % YELLOW,
|
||||
'%(blue)': COLOR_SEQ % BLUE,
|
||||
'%(magenta)': COLOR_SEQ % MAGENTA,
|
||||
'%(cyan)': COLOR_SEQ % CYAN,
|
||||
'%(white)': COLOR_SEQ % WHITE,
|
||||
'%(reset)': RESET_SEQ,
|
||||
'%(bold)': BOLD_SEQ,
|
||||
}
|
||||
for key, value in cmap.items():
|
||||
fmt = fmt.replace(key, value)
|
||||
return fmt
|
||||
|
||||
|
||||
log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' +
|
||||
'[%(cyan)%(app)-15s%(reset)]' +
|
||||
'[%(cyan)%(context)-24s%(reset)]' +
|
||||
'[%(cyan)%(actionstr)-22s%(reset)]' +
|
||||
' %(bold)%(cyan)%(name)s:%(reset)' +
|
||||
' %(white)%(message)s%(ctxstr)s%(reset)')
|
||||
log_format = colour_escape(log_format)
|
||||
|
||||
|
||||
# Setup the logger
|
||||
logger = logging.getLogger()
|
||||
log_fmt = logging.Formatter(
|
||||
fmt=log_format,
|
||||
# datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
|
||||
class LessThanFilter(logging.Filter):
|
||||
def __init__(self, exclusive_maximum, name=""):
|
||||
super(LessThanFilter, self).__init__(name)
|
||||
self.max_level = exclusive_maximum
|
||||
|
||||
def filter(self, record):
|
||||
# non-zero return means we log this message
|
||||
return 1 if record.levelno < self.max_level else 0
|
||||
|
||||
|
||||
class ThreadFilter(logging.Filter):
|
||||
def __init__(self, thread_name):
|
||||
super().__init__("")
|
||||
self.thread = thread_name
|
||||
|
||||
def filter(self, record):
|
||||
# non-zero return means we log this message
|
||||
return 1 if record.threadName == self.thread else 0
|
||||
|
||||
|
||||
class ContextInjection(logging.Filter):
|
||||
def filter(self, record):
|
||||
# These guards are to allow override through _extra
|
||||
# And to ensure the injection is idempotent
|
||||
if not hasattr(record, 'context'):
|
||||
record.context = log_context.get()
|
||||
|
||||
if not hasattr(record, 'actionstr'):
|
||||
action_stack = log_action_stack.get()
|
||||
if hasattr(record, 'action'):
|
||||
action_stack = (*action_stack, record.action)
|
||||
if action_stack:
|
||||
record.actionstr = ' ➔ '.join(action_stack)
|
||||
else:
|
||||
record.actionstr = "Unknown Action"
|
||||
|
||||
if not hasattr(record, 'app'):
|
||||
record.app = log_app.get()
|
||||
|
||||
if not hasattr(record, 'ctx'):
|
||||
if ctx := context.get():
|
||||
record.ctx = repr(ctx)
|
||||
else:
|
||||
record.ctx = None
|
||||
|
||||
if getattr(record, 'with_ctx', False) and record.ctx:
|
||||
record.ctxstr = '\n' + record.ctx
|
||||
else:
|
||||
record.ctxstr = ""
|
||||
return True
|
||||
|
||||
|
||||
logging_handler_out = logging.StreamHandler(sys.stdout)
|
||||
logging_handler_out.setLevel(logging.DEBUG)
|
||||
logging_handler_out.setFormatter(log_fmt)
|
||||
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
|
||||
logging_handler_out.addFilter(ContextInjection())
|
||||
logger.addHandler(logging_handler_out)
|
||||
log_logger.addHandler(logging_handler_out)
|
||||
|
||||
logging_handler_err = logging.StreamHandler(sys.stderr)
|
||||
logging_handler_err.setLevel(logging.WARNING)
|
||||
logging_handler_err.setFormatter(log_fmt)
|
||||
logging_handler_err.addFilter(ContextInjection())
|
||||
logger.addHandler(logging_handler_err)
|
||||
log_logger.addHandler(logging_handler_err)
|
||||
|
||||
|
||||
class LocalQueueHandler(QueueHandler):
|
||||
def _emit(self, record: logging.LogRecord) -> None:
|
||||
# Removed the call to self.prepare(), handle task cancellation
|
||||
try:
|
||||
self.enqueue(record)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class WebHookHandler(logging.StreamHandler):
|
||||
def __init__(self, webhook_url, batch=False, loop=None):
|
||||
super().__init__()
|
||||
self.webhook_url = webhook_url
|
||||
self.batched = ""
|
||||
self.batch = batch
|
||||
self.loop = loop
|
||||
self.batch_delay = 10
|
||||
self.batch_task = None
|
||||
self.last_batched = None
|
||||
self.waiting = []
|
||||
|
||||
def get_loop(self):
|
||||
if self.loop is None:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
return self.loop
|
||||
|
||||
def emit(self, record):
|
||||
self.get_loop().call_soon_threadsafe(self._post, record)
|
||||
|
||||
def _post(self, record):
|
||||
asyncio.create_task(self.post(record))
|
||||
|
||||
async def post(self, record):
|
||||
log_context.set("Webhook Logger")
|
||||
log_action_stack.set(["Logging"])
|
||||
log_app.set(record.app)
|
||||
|
||||
try:
|
||||
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
|
||||
header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
|
||||
context = f"\n# Context: {record.ctx}" if record.ctx else ""
|
||||
message = f"{header}\n{record.msg}{context}"
|
||||
|
||||
if len(message) > 1900:
|
||||
as_file = True
|
||||
else:
|
||||
as_file = False
|
||||
message = "```md\n{}\n```".format(message)
|
||||
|
||||
# Post the log message(s)
|
||||
if self.batch:
|
||||
if len(message) > 1500:
|
||||
await self._send_batched_now()
|
||||
await self._send(message, as_file=as_file)
|
||||
else:
|
||||
self.batched += message
|
||||
if len(self.batched) + len(message) > 1500:
|
||||
await self._send_batched_now()
|
||||
else:
|
||||
asyncio.create_task(self._schedule_batched())
|
||||
else:
|
||||
await self._send(message, as_file=as_file)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
|
||||
async def _schedule_batched(self):
|
||||
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
|
||||
# noop, don't reschedule if it is already scheduled
|
||||
return
|
||||
try:
|
||||
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
|
||||
await self.batch_task
|
||||
await self._send_batched()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
|
||||
async def _send_batched_now(self):
|
||||
if self.batch_task is not None and not self.batch_task.done():
|
||||
self.batch_task.cancel()
|
||||
self.last_batched = None
|
||||
await self._send_batched()
|
||||
|
||||
async def _send_batched(self):
|
||||
if self.batched:
|
||||
batched = self.batched
|
||||
self.batched = ""
|
||||
await self._send(batched)
|
||||
|
||||
async def _send(self, message, as_file=False):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
webhook = Webhook.from_url(self.webhook_url, session=session)
|
||||
if as_file or len(message) > 2000:
|
||||
with StringIO(message) as fp:
|
||||
fp.seek(0)
|
||||
await webhook.send(
|
||||
file=File(fp, filename="logs.md"),
|
||||
username=log_app.get()
|
||||
)
|
||||
else:
|
||||
await webhook.send(message, username=log_app.get())
|
||||
|
||||
|
||||
handlers = []
|
||||
if webhook := conf.logging['general_log']:
|
||||
handler = WebHookHandler(webhook, batch=True)
|
||||
handlers.append(handler)
|
||||
|
||||
if webhook := conf.logging['error_log']:
|
||||
handler = WebHookHandler(webhook, batch=False)
|
||||
handler.setLevel(logging.ERROR)
|
||||
handlers.append(handler)
|
||||
|
||||
if webhook := conf.logging['critical_log']:
|
||||
handler = WebHookHandler(webhook, batch=False)
|
||||
handler.setLevel(logging.CRITICAL)
|
||||
handlers.append(handler)
|
||||
|
||||
if handlers:
|
||||
# First create a separate loop to run the handlers on
|
||||
import threading
|
||||
|
||||
def run_loop(loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_forever()
|
||||
finally:
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.close()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
loop_thread = threading.Thread(target=lambda: run_loop(loop))
|
||||
loop_thread.daemon = True
|
||||
loop_thread.start()
|
||||
|
||||
for handler in handlers:
|
||||
handler.loop = loop
|
||||
|
||||
queue: SimpleQueue[logging.LogRecord] = SimpleQueue()
|
||||
|
||||
qhandler = QueueHandler(queue)
|
||||
qhandler.setLevel(logging.INFO)
|
||||
qhandler.addFilter(ContextInjection())
|
||||
# qhandler.addFilter(ThreadFilter('MainThread'))
|
||||
logger.addHandler(qhandler)
|
||||
|
||||
listener = QueueListener(
|
||||
queue, *handlers, respect_handler_level=True
|
||||
)
|
||||
listener.start()
|
||||
34
src/meta/pending-rewrite/client.py
Normal file
34
src/meta/pending-rewrite/client.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from discord import Intents
|
||||
from cmdClient.cmdClient import cmdClient
|
||||
|
||||
from . import patches
|
||||
from .interactions import InteractionType
|
||||
from .config import conf
|
||||
from .sharding import shard_number, shard_count
|
||||
from LionContext import LionContext
|
||||
|
||||
|
||||
# Initialise client
|
||||
owners = [int(owner) for owner in conf.bot.getlist('owners')]
|
||||
intents = Intents.all()
|
||||
intents.presences = False
|
||||
client = cmdClient(
|
||||
prefix=conf.bot['prefix'],
|
||||
owners=owners,
|
||||
intents=intents,
|
||||
shard_id=shard_number,
|
||||
shard_count=shard_count,
|
||||
baseContext=LionContext
|
||||
)
|
||||
client.conf = conf
|
||||
|
||||
|
||||
# TODO: Could include client id here, or app id, to avoid multiple handling.
|
||||
NOOP_ID = 'NOOP'
|
||||
|
||||
|
||||
@client.add_after_event('interaction_create')
|
||||
async def handle_noop_interaction(client, interaction):
|
||||
if interaction.interaction_type in (InteractionType.MESSAGE_COMPONENT, InteractionType.MODAL_SUBMIT):
|
||||
if interaction.custom_id == NOOP_ID:
|
||||
interaction.ack()
|
||||
110
src/meta/pending-rewrite/logger.py
Normal file
110
src/meta/pending-rewrite/logger.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
from discord import AllowedMentions
|
||||
|
||||
from cmdClient.logger import cmd_log_handler
|
||||
|
||||
from utils.lib import mail, split_text
|
||||
|
||||
from .client import client
|
||||
from .config import conf
|
||||
from . import sharding
|
||||
|
||||
|
||||
# Setup the logger
|
||||
logger = logging.getLogger()
|
||||
log_fmt = logging.Formatter(
|
||||
fmt=('[{asctime}][{levelname:^8}]' +
|
||||
'[SHARD {}]'.format(sharding.shard_number) +
|
||||
' {message}'),
|
||||
datefmt='%d/%m | %H:%M:%S',
|
||||
style='{'
|
||||
)
|
||||
# term_handler = logging.StreamHandler(sys.stdout)
|
||||
# term_handler.setFormatter(log_fmt)
|
||||
# logger.addHandler(term_handler)
|
||||
# logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class LessThanFilter(logging.Filter):
|
||||
def __init__(self, exclusive_maximum, name=""):
|
||||
super(LessThanFilter, self).__init__(name)
|
||||
self.max_level = exclusive_maximum
|
||||
|
||||
def filter(self, record):
|
||||
# non-zero return means we log this message
|
||||
return 1 if record.levelno < self.max_level else 0
|
||||
|
||||
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
logging_handler_out = logging.StreamHandler(sys.stdout)
|
||||
logging_handler_out.setLevel(logging.DEBUG)
|
||||
logging_handler_out.setFormatter(log_fmt)
|
||||
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
|
||||
logger.addHandler(logging_handler_out)
|
||||
|
||||
logging_handler_err = logging.StreamHandler(sys.stderr)
|
||||
logging_handler_err.setLevel(logging.WARNING)
|
||||
logging_handler_err.setFormatter(log_fmt)
|
||||
logger.addHandler(logging_handler_err)
|
||||
|
||||
|
||||
# Define the context log format and attach it to the command logger as well
|
||||
@cmd_log_handler
|
||||
def log(message, context="GLOBAL", level=logging.INFO, post=True):
|
||||
# Add prefixes to lines for better parsing capability
|
||||
lines = message.splitlines()
|
||||
if len(lines) > 1:
|
||||
lines = [
|
||||
'┌ ' * (i == 0) + '│ ' * (0 < i < len(lines) - 1) + '└ ' * (i == len(lines) - 1) + line
|
||||
for i, line in enumerate(lines)
|
||||
]
|
||||
else:
|
||||
lines = ['─ ' + message]
|
||||
|
||||
for line in lines:
|
||||
logger.log(level, '\b[{}] {}'.format(
|
||||
str(context).center(22, '='),
|
||||
line
|
||||
))
|
||||
|
||||
# Fire and forget to the channel logger, if it is set up
|
||||
if post and client.is_ready():
|
||||
asyncio.ensure_future(live_log(message, context, level))
|
||||
|
||||
|
||||
# Live logger that posts to the logging channels
|
||||
async def live_log(message, context, level):
|
||||
if level >= logging.INFO:
|
||||
if level >= logging.WARNING:
|
||||
log_chid = conf.bot.getint('error_channel') or conf.bot.getint('log_channel')
|
||||
else:
|
||||
log_chid = conf.bot.getint('log_channel')
|
||||
|
||||
# Generate the log messages
|
||||
if sharding.sharded:
|
||||
header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
|
||||
else:
|
||||
header = f"[{logging.getLevelName(level)}][{context}]"
|
||||
|
||||
if len(message) > 1900:
|
||||
blocks = split_text(message, blocksize=1900, code=False)
|
||||
else:
|
||||
blocks = [message]
|
||||
|
||||
if len(blocks) > 1:
|
||||
blocks = [
|
||||
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
|
||||
]
|
||||
else:
|
||||
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
|
||||
|
||||
# Post the log messages
|
||||
if log_chid:
|
||||
[await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]
|
||||
|
||||
|
||||
# Attach logger to client, for convenience
|
||||
client.log = log
|
||||
35
src/meta/sharding.py
Normal file
35
src/meta/sharding.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from .args import args
|
||||
from .config import conf
|
||||
|
||||
from psycopg import sql
|
||||
from data.conditions import Condition, Joiner
|
||||
|
||||
|
||||
shard_number = args.shard or 0
|
||||
|
||||
shard_count = conf.bot.getint('shard_count', 1)
|
||||
|
||||
sharded = (shard_count > 0)
|
||||
|
||||
|
||||
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
|
||||
"""
|
||||
Condition constructor for filtering by shard id.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
Query.where(_shard_condition('guildid', 10, 1))
|
||||
"""
|
||||
return Condition(
|
||||
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
|
||||
guildid=sql.Identifier(guild_column),
|
||||
shard_count=sql.Literal(shard_count)
|
||||
),
|
||||
Joiner.EQUALS,
|
||||
sql.Placeholder(),
|
||||
(shard_id,)
|
||||
)
|
||||
|
||||
|
||||
# Pre-built Condition for filtering by current shard.
|
||||
THIS_SHARD = SHARDID(shard_number)
|
||||
16
src/modules/__init__.py
Normal file
16
src/modules/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
this_package = 'modules'
|
||||
|
||||
active = [
|
||||
'.sysadmin',
|
||||
'.config',
|
||||
'.economy',
|
||||
'.reminders',
|
||||
'.shop',
|
||||
'.tasklist',
|
||||
'.test',
|
||||
]
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
for ext in active:
|
||||
await bot.load_extension(ext, package=this_package)
|
||||
10
src/modules/config/__init__.py
Normal file
10
src/modules/config/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import logging
|
||||
from babel.translator import LocalBabel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
babel = LocalBabel('config')
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
from .cog import ConfigCog
|
||||
await bot.add_cog(ConfigCog(bot))
|
||||
30
src/modules/config/cog.py
Normal file
30
src/modules/config/cog.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import discord
|
||||
from discord import app_commands as appcmds
|
||||
from discord.ext import commands as cmds
|
||||
|
||||
from meta import LionBot, LionContext, LionCog
|
||||
|
||||
from . import babel
|
||||
|
||||
_p = babel._p
|
||||
|
||||
|
||||
class ConfigCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
|
||||
async def cog_load(self):
|
||||
...
|
||||
|
||||
async def cog_unload(self):
|
||||
...
|
||||
|
||||
@cmds.hybrid_group(
|
||||
name=_p('group:configure', "configure"),
|
||||
)
|
||||
@appcmds.guild_only
|
||||
async def configure_group(self, ctx: LionContext):
|
||||
"""
|
||||
Bare command group, has no function.
|
||||
"""
|
||||
return
|
||||
5
src/modules/economy/__init__.py
Normal file
5
src/modules/economy/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .cog import Economy
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(Economy(bot))
|
||||
969
src/modules/economy/cog.py
Normal file
969
src/modules/economy/cog.py
Normal file
@@ -0,0 +1,969 @@
|
||||
from typing import Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
from psycopg import sql
|
||||
from data import Registry, RowModel, RegisterEnum, ORDER, JOINTYPE, RawExpr
|
||||
from data.columns import Integer, Bool, String, Column, Timestamp
|
||||
|
||||
from meta import LionCog, LionBot, LionContext
|
||||
from meta.errors import ResponseTimedOut
|
||||
from babel import LocalBabel
|
||||
|
||||
from core.data import CoreData
|
||||
|
||||
from utils.ui import LeoUI, LeoModal, Confirm, Pager
|
||||
from utils.lib import error_embed, MessageArgs, utc_now
|
||||
|
||||
babel = LocalBabel('economy')
|
||||
_, _p, _np = babel._, babel._p, babel._np
|
||||
|
||||
|
||||
MAX_COINS = 2**16
|
||||
|
||||
|
||||
class TransactionType(Enum):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TYPE CoinTransactionType AS ENUM(
|
||||
'REFUND',
|
||||
'TRANSFER',
|
||||
'SHOP_PURCHASE',
|
||||
'STUDY_SESSION',
|
||||
'ADMIN',
|
||||
'TASKS'
|
||||
);
|
||||
"""
|
||||
REFUND = 'REFUND',
|
||||
TRANSFER = 'TRANSFER',
|
||||
PURCHASE = 'SHOP_PURCHASE',
|
||||
SESSION = 'STUDY_SESSION',
|
||||
ADMIN = 'ADMIN',
|
||||
TASKS = 'TASKS',
|
||||
|
||||
|
||||
class AdminActionTarget(Enum):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TYPE EconAdminTarget AS ENUM(
|
||||
'ROLE',
|
||||
'USER',
|
||||
'GUILD'
|
||||
);
|
||||
"""
|
||||
ROLE = 'ROLE',
|
||||
USER = 'USER',
|
||||
GUILD = 'GUILD',
|
||||
|
||||
|
||||
class AdminActionType(Enum):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TYPE EconAdminAction AS ENUM(
|
||||
'SET',
|
||||
'ADD'
|
||||
);
|
||||
"""
|
||||
SET = 'SET',
|
||||
ADD = 'ADD',
|
||||
|
||||
|
||||
class EconomyData(Registry, name='economy'):
|
||||
_TransactionType = RegisterEnum(TransactionType, 'CoinTransactionType')
|
||||
_AdminActionTarget = RegisterEnum(AdminActionTarget, 'EconAdminTarget')
|
||||
_AdminActionType = RegisterEnum(AdminActionType, 'EconAdminAction')
|
||||
|
||||
class Transaction(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE coin_transactions(
|
||||
transactionid SERIAL PRIMARY KEY,
|
||||
transactiontype CoinTransactionType NOT NULL,
|
||||
guildid BIGINT NOT NULL REFERENCES guild_config (guildid) ON DELETE CASCADE,
|
||||
actorid BIGINT NOT NULL,
|
||||
amount INTEGER NOT NULL,
|
||||
bonus INTEGER NOT NULL,
|
||||
from_account BIGINT,
|
||||
to_account BIGINT,
|
||||
refunds INTEGER REFERENCES coin_transactions (transactionid) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT (now() at time zone 'utc')
|
||||
);
|
||||
CREATE INDEX coin_transaction_guilds ON coin_transactions (guildid);
|
||||
"""
|
||||
_tablename_ = 'coin_transactions'
|
||||
|
||||
transactionid = Integer(primary=True)
|
||||
transactiontype: Column[TransactionType] = Column()
|
||||
guildid = Integer()
|
||||
actorid = Integer()
|
||||
amount = Integer()
|
||||
bonus = Integer()
|
||||
from_account = Integer()
|
||||
to_account = Integer()
|
||||
refunds = Integer()
|
||||
created_at = Timestamp()
|
||||
|
||||
@classmethod
|
||||
async def execute_transaction(
|
||||
cls,
|
||||
transaction_type: TransactionType,
|
||||
guildid: int, actorid: int,
|
||||
from_account: int, to_account: int, amount: int, bonus: int = 0,
|
||||
refunds: int = None
|
||||
):
|
||||
transaction = await cls.create(
|
||||
transactiontype=transaction_type,
|
||||
guildid=guildid, actorid=actorid, amount=amount, bonus=bonus,
|
||||
from_account=from_account, to_account=to_account,
|
||||
refunds=refunds
|
||||
)
|
||||
if from_account is not None:
|
||||
await CoreData.Member.table.update_where(
|
||||
guildid=guildid, userid=from_account
|
||||
).set(coins=(CoreData.Member.coins - (amount + bonus)))
|
||||
if to_account is not None:
|
||||
await CoreData.Member.table.update_where(
|
||||
guildid=guildid, userid=to_account
|
||||
).set(coins=(CoreData.Member.coins + (amount + bonus)))
|
||||
return transaction
|
||||
|
||||
class ShopTransaction(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE coin_transactions_shop(
|
||||
transactionid INTEGER PRIMARY KEY REFERENCES coin_transactions (transactionid) ON DELETE CASCADE,
|
||||
itemid INTEGER NOT NULL REFERENCES shop_items (itemid) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'coin_transactions_shop'
|
||||
|
||||
transactionid = Integer(primary=True)
|
||||
itemid = Integer()
|
||||
|
||||
@classmethod
|
||||
async def purchase_transaction(
|
||||
cls,
|
||||
guildid: int, actorid: int,
|
||||
userid: int, itemid: int, amount: int
|
||||
):
|
||||
conn = await cls._connector.get_connection()
|
||||
async with conn.transaction():
|
||||
row = await EconomyData.Transaction.execute_transaction(
|
||||
TransactionType.PURCHASE,
|
||||
guildid=guildid, actorid=actorid, from_account=userid, to_account=None,
|
||||
amount=amount
|
||||
)
|
||||
return await cls.create(transactionid=row.transactionid, itemid=itemid)
|
||||
|
||||
class TaskTransaction(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE coin_transactions_tasks(
|
||||
transactionid INTEGER PRIMARY KEY REFERENCES coin_transactions (transactionid) ON DELETE CASCADE,
|
||||
count INTEGER NOT NULL
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'coin_transactions_tasks'
|
||||
|
||||
transactionid = Integer(primary=True)
|
||||
count = Integer()
|
||||
|
||||
@classmethod
|
||||
async def count_recent_for(cls, userid, guildid, interval='24h'):
|
||||
"""
|
||||
Retrieve the number of tasks rewarded in the last `interval`.
|
||||
"""
|
||||
T = EconomyData.Transaction
|
||||
query = cls.table.select_where().with_no_adapter()
|
||||
query.join(T, using=(T.transactionid.name, ), join_type=JOINTYPE.LEFT)
|
||||
query.select(recent=sql.SQL("SUM({})").format(cls.count.expr))
|
||||
query.where(
|
||||
T.to_account == userid,
|
||||
T.guildid == guildid,
|
||||
T.created_at > RawExpr(sql.SQL("timezone('utc', NOW()) - INTERVAL {}").format(interval), ()),
|
||||
)
|
||||
result = await query
|
||||
return result[0]['recent'] or 0
|
||||
|
||||
@classmethod
|
||||
async def reward_completed(cls, userid, guildid, count, amount):
|
||||
"""
|
||||
Reward the specified member `amount` coins for completing `count` tasks.
|
||||
"""
|
||||
# TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog?
|
||||
conn = await cls._connector.get_connection()
|
||||
async with conn.transaction():
|
||||
row = await EconomyData.Transaction.execute_transaction(
|
||||
TransactionType.TASKS,
|
||||
guildid=guildid, actorid=userid, from_account=None, to_account=userid,
|
||||
amount=amount
|
||||
)
|
||||
return await cls.create(transactionid=row.transactionid, count=count)
|
||||
|
||||
class SessionTransaction(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE coin_transactions_sessions(
|
||||
transactionid INTEGER PRIMARY KEY REFERENCES coin_transactions (transactionid) ON DELETE CASCADE,
|
||||
sessionid INTEGER NOT NULL REFERENCES session_history (sessionid) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'coin_transactions_sessions'
|
||||
|
||||
transactionid = Integer(primary=True)
|
||||
sessionid = Integer()
|
||||
|
||||
class AdminActions(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE economy_admin_actions(
|
||||
actionid SERIAL PRIMARY KEY,
|
||||
target_type EconAdminTarget NOT NULL,
|
||||
action_type EconAdminAction NOT NULL,
|
||||
targetid INTEGER NOT NULL,
|
||||
amount INTEGER NOT NULL
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'economy_admin_actions'
|
||||
|
||||
actionid = Integer(primary=True)
|
||||
target_type: Column[AdminActionTarget] = Column()
|
||||
action_type: Column[AdminActionType] = Column()
|
||||
targetid = Integer()
|
||||
amount = Integer()
|
||||
|
||||
class AdminTransactions(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE coin_transactions_admin_actions(
|
||||
actionid INTEGER NOT NULL REFERENCES economy_admin_actions (actionid),
|
||||
transactionid INTEGER NOT NULL REFERENCES coin_transactions (transactionid),
|
||||
PRIMARY KEY (actionid, transactionid)
|
||||
);
|
||||
CREATE INDEX coin_transactions_admin_actions_transactionid ON coin_transactions_admin_actions (transactionid);
|
||||
"""
|
||||
_tablename_ = 'coin_transactions_admin_actions'
|
||||
|
||||
actionid = Integer(primary=True)
|
||||
transactionid = Integer(primary=True)
|
||||
|
||||
|
||||
class Economy(LionCog):
|
||||
"""
|
||||
Commands
|
||||
--------
|
||||
/economy balances [target:<mentionable>] [add:<int>] [set:<int>].
|
||||
With no arguments, show a summary of current balances in the server.
|
||||
With a target user or role, show their balance, and possibly their most recent transactions.
|
||||
With a target user or role, and add or set, modify their balance. Confirm if more than 1 user is affected.
|
||||
With no target user or role, apply to everyone in the guild. Confirm if more than 1 user affected.
|
||||
|
||||
/economy reset [target:<mentionable>]
|
||||
Reset the economy system for the given target, or everyone in the guild.
|
||||
Acts as an alias to `/economy balances target:target set:0
|
||||
|
||||
/economy history [target:<mentionable>]
|
||||
Display a paged audit trail with the history of the selected member,
|
||||
all the users in the selected role, or all users.
|
||||
|
||||
/sendcoins <user:<user>> [note:<str>]
|
||||
Send coins to the specified user, with an optional note.
|
||||
"""
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = bot.db.load_registry(EconomyData())
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
# ----- Economy group commands -----
|
||||
|
||||
@cmds.hybrid_group(name=_p('cmd:economy', "economy"))
|
||||
@cmds.guild_only()
|
||||
async def economy_group(self, ctx: LionContext):
|
||||
pass
|
||||
|
||||
@economy_group.command(
|
||||
name=_p('cmd:economy_balance', "balance"),
|
||||
description=_p(
|
||||
'cmd:economy_balance|desc',
|
||||
"Display and modify LionCoin balance for members or roles."
|
||||
)
|
||||
)
|
||||
@appcmds.rename(
|
||||
target=_p('cmd:economy_balance|param:target', "target"),
|
||||
add=_p('cmd:economy_balance|param:add', "add"),
|
||||
set_to=_p('cmd:economy_balance|param:set', "set")
|
||||
)
|
||||
@appcmds.describe(
|
||||
target=_p(
|
||||
'cmd:economy_balance|param:target|desc',
|
||||
"Target user or role to view or update. Use @everyone to update the entire guild."
|
||||
),
|
||||
add=_p(
|
||||
'cmd:economy_balance|param:add|desc',
|
||||
"Number of LionCoins to add to the target member's balance. May be negative to remove."
|
||||
),
|
||||
set_to=_p(
|
||||
'cmd:economy_balance|param:set|set',
|
||||
"New balance to set the target's balance to."
|
||||
)
|
||||
)
|
||||
async def economy_balance_cmd(
|
||||
self,
|
||||
ctx: LionContext,
|
||||
target: discord.User | discord.Member | discord.Role,
|
||||
set_to: Optional[appcmds.Range[int, 0, MAX_COINS]] = None,
|
||||
add: Optional[int] = None
|
||||
):
|
||||
t = self.bot.translator.t
|
||||
cemoji = self.bot.config.emojis.getemoji('coin')
|
||||
targets: list[Union[discord.User, discord.Member]]
|
||||
|
||||
if not ctx.guild:
|
||||
# Added for the typechecker
|
||||
# This is impossible from the guild_only ward
|
||||
return
|
||||
if not self.bot.core:
|
||||
return
|
||||
if not ctx.interaction:
|
||||
return
|
||||
|
||||
if isinstance(target, discord.Role):
|
||||
targets = [mem for mem in target.members if not mem.bot]
|
||||
role = target
|
||||
else:
|
||||
targets = [target]
|
||||
role = None
|
||||
|
||||
if role and not targets:
|
||||
# Guard against provided target role having no members
|
||||
# Possible chunking failed for this guild, want to explicitly inform.
|
||||
await ctx.reply(
|
||||
embed=error_embed(
|
||||
t(_p(
|
||||
'cmd:economy_balance|error:no_target',
|
||||
"There are no valid members in {role.mention}! It has a total of `0` LC."
|
||||
)).format(role=target)
|
||||
),
|
||||
ephemeral=True
|
||||
)
|
||||
elif not role and target.bot:
|
||||
# Guard against reading or modifying a bot account
|
||||
await ctx.reply(
|
||||
embed=error_embed(
|
||||
t(_p(
|
||||
'cmd:economy_balance|error:target_is_bot',
|
||||
"Bots cannot have coin balances!"
|
||||
))
|
||||
),
|
||||
ephemeral=True
|
||||
)
|
||||
elif set_to is not None and add is not None:
|
||||
# Requested operation doesn't make sense
|
||||
await ctx.reply(
|
||||
embed=error_embed(
|
||||
t(_p(
|
||||
'cmd:economy_balance|error:args',
|
||||
"You cannot simultaneously `set` and `add` member balances!"
|
||||
))
|
||||
),
|
||||
ephemeral=True
|
||||
)
|
||||
elif set_to is not None or add is not None:
|
||||
# Setting route
|
||||
# First ensure all the targets we will be updating already have rows
|
||||
# As this is one of the few operations that acts on members not already registered,
|
||||
# We may need to do a mass row create operation.
|
||||
targetids = set(target.id for target in targets)
|
||||
if len(targets) > 1:
|
||||
conn = await ctx.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
# First fetch the members which currently exist
|
||||
query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id)
|
||||
query.select('userid').with_no_adapter()
|
||||
if 2 * len(targets) < len(ctx.guild.members):
|
||||
# More efficient to fetch the targets explicitly
|
||||
query.where(userid=list(targetids))
|
||||
existent_rows = await query
|
||||
existentids = set(r['userid'] for r in existent_rows)
|
||||
|
||||
# Then check if any new userids need adding, and if so create them
|
||||
new_ids = targetids.difference(existentids)
|
||||
if new_ids:
|
||||
# We use ON CONFLICT IGNORE here in case the users already exist.
|
||||
await self.bot.core.data.User.table.insert_many(
|
||||
('userid',),
|
||||
*((id,) for id in new_ids)
|
||||
).on_conflict(ignore=True)
|
||||
# TODO: Replace 0 here with the starting_coin value
|
||||
await self.bot.core.data.Member.table.insert_many(
|
||||
('guildid', 'userid', 'coins'),
|
||||
*((ctx.guild.id, id, 0) for id in new_ids)
|
||||
).on_conflict(ignore=True)
|
||||
else:
|
||||
# With only one target, we can take a simpler path, and make better use of local caches.
|
||||
await self.bot.core.lions.fetch(ctx.guild.id, target.id)
|
||||
# Now we are certain these members have a database row
|
||||
|
||||
# Perform the appropriate action
|
||||
if role:
|
||||
affected = t(_np(
|
||||
'cmd:economy_balance|embed:success|affected',
|
||||
"One user was affected.",
|
||||
"**{count}** users were affected.",
|
||||
len(targets)
|
||||
)).format(count=len(targets))
|
||||
conf_affected = t(_np(
|
||||
'cmd:economy_balance|confirm|affected',
|
||||
"One user will be affected.",
|
||||
"**{count}** users will be affected.",
|
||||
len(targets)
|
||||
)).format(count=len(targets))
|
||||
confirm = Confirm(conf_affected)
|
||||
confirm.confirm_button = t(_p(
|
||||
'cmd:economy_balance|confirm|button:confirm',
|
||||
"Yes, adjust balances"
|
||||
))
|
||||
confirm.confirm_button = t(_p(
|
||||
'cmd:economy_balance|confirm|button:cancel',
|
||||
"No, cancel"
|
||||
))
|
||||
if set_to is not None:
|
||||
if role:
|
||||
if role.is_default():
|
||||
description = t(_p(
|
||||
'cmd:economy_balance|embed:success_set|desc',
|
||||
"All members of **{guild_name}** have had their "
|
||||
"balance set to {coin_emoji}**{amount}**."
|
||||
)).format(
|
||||
guild_name=ctx.guild.name,
|
||||
coin_emoji=cemoji,
|
||||
amount=set_to
|
||||
) + '\n' + affected
|
||||
conf_description = t(_p(
|
||||
'cmd:economy_balance|confirm_set|desc',
|
||||
"Are you sure you want to set everyone's balance to {coin_emoji}**{amount}**?"
|
||||
)).format(
|
||||
coin_emoji=cemoji,
|
||||
amount=set_to
|
||||
) + '\n' + conf_affected
|
||||
else:
|
||||
description = t(_p(
|
||||
'cmd:economy_balance|embed:success_set|desc',
|
||||
"All members of {role_mention} have had their "
|
||||
"balance set to {coin_emoji}**{amount}**."
|
||||
)).format(
|
||||
role_mention=role.mention,
|
||||
coin_emoji=cemoji,
|
||||
amount=set_to
|
||||
) + '\n' + affected
|
||||
conf_description = t(_p(
|
||||
'cmd:economy_balance|confirm_set|desc',
|
||||
"Are you sure you want to set the balance of everyone with {role_mention} "
|
||||
"to {coin_emoji}**{amount}**?"
|
||||
)).format(
|
||||
role_mention=role.mention,
|
||||
coin_emoji=cemoji,
|
||||
amount=set_to
|
||||
) + '\n' + conf_affected
|
||||
confirm.embed.description = conf_description
|
||||
try:
|
||||
result = await confirm.ask(ctx.interaction, ephemeral=True)
|
||||
except ResponseTimedOut:
|
||||
return
|
||||
if not result:
|
||||
return
|
||||
else:
|
||||
description = t(_p(
|
||||
'cmd:economy_balance|embed:success_set|desc',
|
||||
"{user_mention} now has a balance of {coin_emoji}**{amount}**."
|
||||
)).format(
|
||||
user_mention=target.mention,
|
||||
coin_emoji=cemoji,
|
||||
amount=set_to
|
||||
)
|
||||
await self.bot.core.data.Member.table.update_where(
|
||||
guildid=ctx.guild.id, userid=list(targetids)
|
||||
).set(
|
||||
coins=set_to
|
||||
)
|
||||
else:
|
||||
if role:
|
||||
if role.is_default():
|
||||
description = t(_p(
|
||||
'cmd:economy_balance|embed:success_add|desc',
|
||||
"All members of **{guild_name}** have been given "
|
||||
"{coin_emoji}**{amount}**."
|
||||
)).format(
|
||||
guild_name=ctx.guild.name,
|
||||
coin_emoji=cemoji,
|
||||
amount=add
|
||||
) + '\n' + affected
|
||||
conf_description = t(_p(
|
||||
'cmd:economy_balance|confirm_add|desc',
|
||||
"Are you sure you want to add **{amount}** to everyone's balance?"
|
||||
)).format(
|
||||
coin_emoji=cemoji,
|
||||
amount=add
|
||||
) + '\n' + conf_affected
|
||||
else:
|
||||
description = t(_p(
|
||||
'cmd:economy_balance|embed:success_add|desc',
|
||||
"All members of {role_mention} have been given "
|
||||
"{coin_emoji}**{amount}**."
|
||||
)).format(
|
||||
role_mention=role.mention,
|
||||
coin_emoji=cemoji,
|
||||
amount=add
|
||||
) + '\n' + affected
|
||||
conf_description = t(_p(
|
||||
'cmd:economy_balance|confirm_add|desc',
|
||||
"Are you sure you want to add {coin_emoji}**{amount}** to everyone in {role_mention}?"
|
||||
)).format(
|
||||
coin_emoji=cemoji,
|
||||
amount=add,
|
||||
role_mention=role.mention
|
||||
) + '\n' + conf_affected
|
||||
confirm.embed.description = conf_description
|
||||
try:
|
||||
result = await confirm.ask(ctx.interaction, ephemeral=True)
|
||||
except ResponseTimedOut:
|
||||
return
|
||||
if not result:
|
||||
return
|
||||
results = await self.bot.core.data.Member.table.update_where(
|
||||
guildid=ctx.guild.id, userid=list(targetids)
|
||||
).set(
|
||||
coins=(self.bot.core.data.Member.coins + add)
|
||||
)
|
||||
# Single member case occurs afterwards so we can pick up the results
|
||||
if not role:
|
||||
description = t(_p(
|
||||
'cmd:economy_balance|embed:success_add|desc',
|
||||
"{user_mention} was given {coin_emoji}**{amount}**, and "
|
||||
"now has a balance of {coin_emoji}**{new_amount}**."
|
||||
)).format(
|
||||
user_mention=target.mention,
|
||||
coin_emoji=cemoji,
|
||||
amount=add,
|
||||
new_amount=results[0]['coins']
|
||||
)
|
||||
|
||||
title = t(_np(
|
||||
'cmd:economy_balance|embed:success|title',
|
||||
"Account successfully updated.",
|
||||
"Accounts successfully updated.",
|
||||
len(targets)
|
||||
))
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
description=description,
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Viewing route
|
||||
MemModel = self.bot.core.data.Member
|
||||
if role:
|
||||
query = MemModel.fetch_where(
|
||||
(MemModel.guildid == role.guild.id) & (MemModel.coins != 0)
|
||||
)
|
||||
query.order_by('coins', ORDER.DESC)
|
||||
if not role.is_default():
|
||||
# Everyone role is handled differently for data efficiency
|
||||
ids = [target.id for target in targets]
|
||||
query = query.where(userid=ids)
|
||||
rows = await query
|
||||
|
||||
name = t(_p(
|
||||
'cmd:economy_balance|embed:role_lb|author',
|
||||
"Balance sheet for {name}"
|
||||
)).format(name=role.name if not role.is_default() else role.guild.name)
|
||||
if rows:
|
||||
if role.is_default():
|
||||
header = t(_p(
|
||||
'cmd:economy_balance|embed:role_lb|header',
|
||||
"This server has a total balance of {coin_emoji}**{total}**."
|
||||
)).format(
|
||||
coin_emoji=cemoji,
|
||||
total=sum(row.coins for row in rows)
|
||||
)
|
||||
else:
|
||||
header = t(_p(
|
||||
'cmd:economy_balance|embed:role_lb|header',
|
||||
"{role_mention} has `{count}` members with non-zero balance, "
|
||||
"with a total balance of {coin_emoji}**{total}**."
|
||||
)).format(
|
||||
count=len(targets),
|
||||
role_mention=role.mention,
|
||||
total=sum(row.coins for row in rows),
|
||||
coin_emoji=cemoji
|
||||
)
|
||||
|
||||
# Build the leaderboard:
|
||||
lb_format = t(_p(
|
||||
'cmd:economy_balance|embed:role_lb|row_format',
|
||||
"`[{pos:>{numwidth}}]` | `{coins:>{coinwidth}} LC` | {mention}"
|
||||
))
|
||||
|
||||
blocklen = 20
|
||||
blocks = [rows[i:i+blocklen] for i in range(0, len(rows), blocklen)]
|
||||
paged = len(blocks) > 1
|
||||
pages = []
|
||||
for i, block in enumerate(blocks):
|
||||
lines = []
|
||||
numwidth = len(str(i + len(block)))
|
||||
coinwidth = len(str(max(row.coins for row in rows)))
|
||||
for j, row in enumerate(block, start=i):
|
||||
lines.append(
|
||||
lb_format.format(
|
||||
pos=j, numwidth=numwidth,
|
||||
coins=row.coins, coinwidth=coinwidth,
|
||||
mention=f"<@{row.userid}>"
|
||||
)
|
||||
)
|
||||
lb_block = '\n'.join(lines)
|
||||
embed = discord.Embed(
|
||||
description=f"{header}\n{lb_block}"
|
||||
)
|
||||
embed.set_author(name=name)
|
||||
if paged:
|
||||
embed.set_footer(
|
||||
text=t(_p(
|
||||
'cmd:economy_balance|embed:role_lb|footer',
|
||||
"Page {page}/{total}"
|
||||
)).format(page=i+1, total=len(blocks))
|
||||
)
|
||||
pages.append(MessageArgs(embed=embed))
|
||||
pager = Pager(pages, show_cancel=True)
|
||||
await pager.run(ctx.interaction)
|
||||
else:
|
||||
if role.is_default():
|
||||
header = t(_p(
|
||||
'cmd:economy_balance|embed:role_lb|header',
|
||||
"This server has a total balance of {coin_emoji}**0**."
|
||||
)).format(
|
||||
coin_emoji=cemoji,
|
||||
)
|
||||
else:
|
||||
header = t(_p(
|
||||
'cmd:economy_balance|embed:role_lb|header',
|
||||
"The role {role_mention} has a total balance of {coin_emoji}**0**."
|
||||
)).format(
|
||||
role_mention=role.mention,
|
||||
coin_emoji=cemoji
|
||||
)
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
description=header
|
||||
)
|
||||
embed.set_author(name=name)
|
||||
await ctx.reply(embed=embed)
|
||||
else:
|
||||
# If we have a single target, show their current balance, with a short transaction history.
|
||||
user = targets[0]
|
||||
row = await self.bot.core.data.Member.fetch(ctx.guild.id, user.id)
|
||||
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
description=t(_p(
|
||||
'cmd:economy_balance|embed:single|desc',
|
||||
"{mention} currently owns {coin_emoji} {coins}."
|
||||
)).format(
|
||||
mention=user.mention,
|
||||
coin_emoji=self.bot.config.emojis.getemoji('coin'),
|
||||
coins=row.coins
|
||||
)
|
||||
).set_author(
|
||||
icon_url=user.avatar,
|
||||
name=t(_p(
|
||||
'cmd:economy_balance|embed:single|author',
|
||||
"Balance statement for {user}"
|
||||
)).format(user=str(user))
|
||||
)
|
||||
await ctx.reply(
|
||||
embed=embed
|
||||
)
|
||||
# TODO: Add small transaction history block when we have transaction formatter
|
||||
|
||||
@economy_group.command(
|
||||
name=_p('cmd:economy_reset', "reset"),
|
||||
description=_p(
|
||||
'cmd:economy_reset|desc',
|
||||
"Reset the coin balance for a target user or role. (See also \"economy balance\".)"
|
||||
)
|
||||
)
|
||||
@appcmds.rename(
|
||||
target=_p('cmd:economy_reset|param:target', "target"),
|
||||
)
|
||||
@appcmds.describe(
|
||||
target=_p(
|
||||
'cmd:economy_reset|param:target|desc',
|
||||
"Target user or role to view or update. Use @everyone to reset the entire guild."
|
||||
),
|
||||
)
|
||||
async def economy_reset_cmd(
|
||||
self,
|
||||
ctx: LionContext,
|
||||
target: discord.User | discord.Member | discord.Role,
|
||||
):
|
||||
# TODO: Permission check
|
||||
t = self.bot.translator.t
|
||||
starting_balance = 0
|
||||
coin_emoji = self.bot.config.emojis.getemoji('coin')
|
||||
|
||||
# Typechecker guards
|
||||
if not ctx.guild:
|
||||
return
|
||||
if not ctx.bot.core:
|
||||
return
|
||||
if not ctx.interaction:
|
||||
return
|
||||
|
||||
if isinstance(target, discord.Role):
|
||||
if target.is_default():
|
||||
# Confirm: Reset Guild
|
||||
confirm_msg = t(_p(
|
||||
'cmd:economy_reset|confirm:reset_guild|desc',
|
||||
"Are you sure you want to reset the coin balance for everyone in **{guild_name}**?\n"
|
||||
"*This is not reversible!*"
|
||||
)).format(
|
||||
guild_name=ctx.guild.name
|
||||
)
|
||||
confirm = Confirm(confirm_msg)
|
||||
confirm.confirm_button.label = t(_p(
|
||||
'cmd:economy_reset|confirm:reset_guild|button:confirm',
|
||||
"Yes, reset the economy"
|
||||
))
|
||||
confirm.cancel_button.label = t(_p(
|
||||
'cmd:economy_reset|confirm:reset_guild|button:cancel',
|
||||
"Cancel reset"
|
||||
))
|
||||
try:
|
||||
result = await confirm.ask(ctx.interaction, ephemeral=True)
|
||||
except ResponseTimedOut:
|
||||
return
|
||||
|
||||
if result:
|
||||
# Complete reset
|
||||
await ctx.bot.core.data.Member.table.update_where(
|
||||
guildid=ctx.guild.id,
|
||||
).set(coins=starting_balance)
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(
|
||||
description=t(_p(
|
||||
'cmd:economy_reset|embed:success_guild|desc',
|
||||
"Everyone in **{guild_name}** has had their balance reset to {coin_emoji}**{amount}**."
|
||||
)).format(
|
||||
guild_name=ctx.guild.name,
|
||||
coin_emoji=coin_emoji,
|
||||
amount=starting_balance
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Provided a role to reset
|
||||
targets = [member for member in target.members if not member.bot]
|
||||
if not targets:
|
||||
# Error: No targets
|
||||
await ctx.reply(
|
||||
embed=error_embed(
|
||||
t(_p(
|
||||
'cmd:economy_reset|error:no_target|desc',
|
||||
"The role {mention} has no members to reset!"
|
||||
)).format(mention=target.mention)
|
||||
),
|
||||
ephemeral=True
|
||||
)
|
||||
else:
|
||||
# Confirm: Reset Role
|
||||
# Include number of people affected
|
||||
confirm_msg = t(_p(
|
||||
'cmd:economy_reset|confirm:reset_role|desc',
|
||||
"Are you sure you want to reset the balance for everyone in {mention}?\n"
|
||||
"**{count}** members will be affected."
|
||||
)).format(
|
||||
mention=target.mention,
|
||||
count=len(targets)
|
||||
)
|
||||
confirm = Confirm(confirm_msg)
|
||||
confirm.confirm_button.label = t(_p(
|
||||
'cmd:economy_reset|confirm:reset_role|button:confirm',
|
||||
"Yes, complete economy reset"
|
||||
))
|
||||
confirm.cancel_button.label = t(_p(
|
||||
'cmd:economy_reset|confirm:reset_role|button:cancel',
|
||||
"Cancel"
|
||||
))
|
||||
try:
|
||||
result = await confirm.ask(ctx.interaction, ephemeral=True)
|
||||
except ResponseTimedOut:
|
||||
return
|
||||
|
||||
if result:
|
||||
# Complete reset
|
||||
await ctx.bot.core.data.Member.table.update_where(
|
||||
guildid=ctx.guild.id,
|
||||
userid=[t.id for t in targets],
|
||||
).set(coins=starting_balance)
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(
|
||||
description=t(_p(
|
||||
'cmd:economy_reset|embed:success_role|desc',
|
||||
"Everyone in {role_mention} has had their "
|
||||
"coin balance reset to {coin_emoji}**{amount}**."
|
||||
)).format(
|
||||
mention=target.mention,
|
||||
coin_emoji=coin_emoji,
|
||||
amount=starting_balance
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Provided an individual user.
|
||||
# Reset their balance
|
||||
# Do not create the member row if it does not already exist.
|
||||
# TODO: Audit logging trail
|
||||
await ctx.bot.core.data.Member.table.update_where(
|
||||
guuildid=ctx.guild.id,
|
||||
userid=target.id,
|
||||
).set(coins=starting_balance)
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(
|
||||
description=t(_p(
|
||||
'cmd:economy_reset|embed:success_user|desc',
|
||||
"{mention}'s balance has been reset to {coin_emoji}**{amount}**."
|
||||
)).format(
|
||||
mention=target.mention,
|
||||
coin_emoji=coin_emoji,
|
||||
amount=starting_balance
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@cmds.hybrid_command(
|
||||
name=_p('cmd:send', "send"),
|
||||
description=_p(
|
||||
'cmd:send|desc',
|
||||
"Gift the target user a certain number of LionCoins."
|
||||
)
|
||||
)
|
||||
@appcmds.rename(
|
||||
target=_p('cmd:send|param:target', "target"),
|
||||
amount=_p('cmd:send|param:amount', "amount"),
|
||||
note=_p('cmd:send|param:note', "note")
|
||||
)
|
||||
@appcmds.describe(
|
||||
target=_p('cmd:send|param:target|desc', "User to send the gift to"),
|
||||
amount=_p('cmd:send|param:amount|desc', "Number of coins to send"),
|
||||
note=_p('cmd:send|param:note|desc', "Optional note to add to the gift.")
|
||||
)
|
||||
@appcmds.guild_only()
|
||||
async def send_cmd(self, ctx: LionContext,
|
||||
target: discord.User | discord.Member,
|
||||
amount: appcmds.Range[int, 1, MAX_COINS],
|
||||
note: Optional[str] = None):
|
||||
"""
|
||||
Send `amount` lioncoins to the provided `target`, with the optional `note` attached.
|
||||
"""
|
||||
if not ctx.interaction:
|
||||
return
|
||||
if not ctx.guild:
|
||||
return
|
||||
if not self.bot.core:
|
||||
return
|
||||
|
||||
t = self.bot.translator.t
|
||||
Member = self.bot.core.data.Member
|
||||
target_lion = await self.bot.core.lions.fetch(ctx.guild.id, target.id)
|
||||
|
||||
# TODO: Add a "Send thanks" button to the DM?
|
||||
# Alternative flow could be waiting until the target user presses accept
|
||||
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
||||
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
# We do this in a transaction so that if something goes wrong,
|
||||
# the coins deduction is rolled back atomicly
|
||||
balance = ctx.alion.data.coins
|
||||
if amount > balance:
|
||||
await ctx.interaction.edit_original_response(
|
||||
embed=error_embed(
|
||||
t(_p(
|
||||
'cmd:send|error:insufficient',
|
||||
"You do not have enough lioncoins to do this!\n"
|
||||
"`Current Balance:` {coin_emoji}{balance}"
|
||||
)).format(
|
||||
coin_emoji=self.bot.config.emojis.getemoji('coin'),
|
||||
balance=balance
|
||||
)
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
# Transfer the coins
|
||||
await ctx.alion.data.update(coins=(Member.coins - amount))
|
||||
await target_lion.data.update(coins=(Member.coins + amount))
|
||||
|
||||
# TODO: Audit trail
|
||||
|
||||
# Message target
|
||||
embed = discord.Embed(
|
||||
title=t(_p(
|
||||
'cmd:send|embed:gift|title',
|
||||
"{user} sent you a gift!"
|
||||
)).format(user=ctx.author.name),
|
||||
description=t(_p(
|
||||
'cmd:send|embed:gift|desc',
|
||||
"{mention} sent you {coin_emoji}**{amount}**."
|
||||
)).format(
|
||||
coin_emoji=self.bot.config.emojis.getemoji('coin'),
|
||||
amount=amount,
|
||||
mention=ctx.author.mention
|
||||
),
|
||||
timestamp=utc_now()
|
||||
)
|
||||
if note:
|
||||
embed.add_field(
|
||||
name="Note Attached",
|
||||
value=note
|
||||
)
|
||||
try:
|
||||
await target.send(embed=embed)
|
||||
failed = False
|
||||
except discord.HTTPException:
|
||||
failed = True
|
||||
pass
|
||||
|
||||
# Ack transfer
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
description=t(_p(
|
||||
'cmd:send|embed:ack|desc',
|
||||
"**{coin_emoji}{amount}** has been deducted from your balance and sent to {mention}!"
|
||||
)).format(
|
||||
coin_emoji=self.bot.config.emojis.getemoji('coin'),
|
||||
amount=amount,
|
||||
mention=target.mention
|
||||
)
|
||||
)
|
||||
if failed:
|
||||
embed.description = t(_p(
|
||||
'cmd:send|embed:ack|desc|error:unreachable',
|
||||
"Unfortunately, I was not able to message the recipient. Perhaps they have me blocked?"
|
||||
))
|
||||
await ctx.interaction.edit_original_response(embed=embed)
|
||||
16
src/modules/pending-rewrite/__init__.py
Normal file
16
src/modules/pending-rewrite/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from .sysadmin import *
|
||||
from .guild_admin import *
|
||||
from .meta import *
|
||||
from .economy import *
|
||||
from .study import *
|
||||
from .stats import *
|
||||
from .user_config import *
|
||||
from .workout import *
|
||||
from .todo import *
|
||||
from .topgg import *
|
||||
from .reminders import *
|
||||
from .renting import *
|
||||
from .moderation import *
|
||||
from .accountability import *
|
||||
from .plugins import *
|
||||
from .sponsors import *
|
||||
477
src/modules/pending-rewrite/accountability/TimeSlot.py
Normal file
477
src/modules/pending-rewrite/accountability/TimeSlot.py
Normal file
@@ -0,0 +1,477 @@
|
||||
from typing import List, Dict
|
||||
import datetime
|
||||
import discord
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
from settings import GuildSettings
|
||||
from utils.lib import tick, cross
|
||||
from core import Lion
|
||||
from meta import client
|
||||
|
||||
from .lib import utc_now
|
||||
from .data import accountability_members, accountability_rooms
|
||||
|
||||
|
||||
class SlotMember:
|
||||
"""
|
||||
Class representing a member booked into an accountability room.
|
||||
Mostly acts as an interface to the corresponding TableRow.
|
||||
But also stores the discord.Member associated, and has several computed properties.
|
||||
The member may be None.
|
||||
"""
|
||||
___slots__ = ('slotid', 'userid', 'guild')
|
||||
|
||||
def __init__(self, slotid, userid, guild):
|
||||
self.slotid = slotid
|
||||
self.userid = userid
|
||||
self.guild = guild
|
||||
|
||||
self._member = None
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
return (self.slotid, self.userid)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return accountability_members.fetch(self.key)
|
||||
|
||||
@property
|
||||
def member(self):
|
||||
return self.guild.get_member(self.userid)
|
||||
|
||||
@property
|
||||
def has_attended(self):
|
||||
return self.data.duration > 0 or self.data.last_joined_at
|
||||
|
||||
|
||||
class TimeSlot:
|
||||
"""
|
||||
Class representing an accountability slot.
|
||||
"""
|
||||
__slots__ = (
|
||||
'guild',
|
||||
'start_time',
|
||||
'data',
|
||||
'lobby',
|
||||
'category',
|
||||
'channel',
|
||||
'message',
|
||||
'members'
|
||||
)
|
||||
|
||||
slots = {}
|
||||
|
||||
_member_overwrite = discord.PermissionOverwrite(
|
||||
view_channel=True,
|
||||
connect=True
|
||||
)
|
||||
|
||||
_everyone_overwrite = discord.PermissionOverwrite(
|
||||
view_channel=False,
|
||||
connect=False,
|
||||
speak=False
|
||||
)
|
||||
|
||||
happy_lion = "https://media.discordapp.net/stickers/898266283559227422.png"
|
||||
sad_lion = "https://media.discordapp.net/stickers/898266548421148723.png"
|
||||
|
||||
def __init__(self, guild, start_time, data=None):
|
||||
self.guild: discord.Guild = guild
|
||||
self.start_time: datetime.datetime = start_time
|
||||
self.data = data
|
||||
|
||||
self.lobby: discord.TextChannel = None # Text channel to post the slot status
|
||||
self.category: discord.CategoryChannel = None # Category to create the voice rooms in
|
||||
self.channel: discord.VoiceChannel = None # Text channel associated with this time slot
|
||||
self.message: discord.Message = None # Status message in lobby channel
|
||||
|
||||
self.members: Dict[int, SlotMember] = {} # memberid -> SlotMember
|
||||
|
||||
@property
|
||||
def open_embed(self):
|
||||
timestamp = int(self.start_time.timestamp())
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Session <t:{}:t> - <t:{}:t>".format(
|
||||
timestamp, timestamp + 3600
|
||||
),
|
||||
colour=discord.Colour.orange(),
|
||||
timestamp=self.start_time
|
||||
).set_footer(
|
||||
text="About to start!\nJoin the session with {}schedule book".format(client.prefix)
|
||||
)
|
||||
|
||||
if self.members:
|
||||
embed.description = "Starting <t:{}:R>.".format(timestamp)
|
||||
embed.add_field(
|
||||
name="Members",
|
||||
value=(
|
||||
', '.join('<@{}>'.format(key) for key in self.members.keys())
|
||||
)
|
||||
)
|
||||
else:
|
||||
embed.description = "No members scheduled this session!"
|
||||
|
||||
return embed
|
||||
|
||||
@property
|
||||
def status_embed(self):
|
||||
timestamp = int(self.start_time.timestamp())
|
||||
embed = discord.Embed(
|
||||
title="Session <t:{}:t> - <t:{}:t>".format(
|
||||
timestamp, timestamp + 3600
|
||||
),
|
||||
description="Finishing <t:{}:R>.".format(timestamp + 3600),
|
||||
colour=discord.Colour.orange(),
|
||||
timestamp=self.start_time
|
||||
).set_footer(text="Join the next session using {}schedule book".format(client.prefix))
|
||||
|
||||
if self.members:
|
||||
classifications = {
|
||||
"Attended": [],
|
||||
"Studying Now": [],
|
||||
"Waiting for": []
|
||||
}
|
||||
for memid, mem in self.members.items():
|
||||
mention = '<@{}>'.format(memid)
|
||||
if not mem.has_attended:
|
||||
classifications["Waiting for"].append(mention)
|
||||
elif mem.member in self.channel.members:
|
||||
classifications["Studying Now"].append(mention)
|
||||
else:
|
||||
classifications["Attended"].append(mention)
|
||||
|
||||
all_attended = all(mem.has_attended for mem in self.members.values())
|
||||
bonus_line = (
|
||||
"{tick} Everyone attended, and will get a `{bonus} LC` bonus!".format(
|
||||
tick=tick,
|
||||
bonus=GuildSettings(self.guild.id).accountability_bonus.value
|
||||
)
|
||||
if all_attended else ""
|
||||
)
|
||||
if all_attended:
|
||||
embed.set_thumbnail(url=self.happy_lion)
|
||||
|
||||
embed.description += "\n" + bonus_line
|
||||
for field, value in classifications.items():
|
||||
if value:
|
||||
embed.add_field(name=field, value='\n'.join(value))
|
||||
else:
|
||||
embed.description = "No members scheduled this session!"
|
||||
|
||||
return embed
|
||||
|
||||
@property
|
||||
def summary_embed(self):
|
||||
timestamp = int(self.start_time.timestamp())
|
||||
embed = discord.Embed(
|
||||
title="Session <t:{}:t> - <t:{}:t>".format(
|
||||
timestamp, timestamp + 3600
|
||||
),
|
||||
description="Finished <t:{}:R>.".format(timestamp + 3600),
|
||||
colour=discord.Colour.orange(),
|
||||
timestamp=self.start_time
|
||||
).set_footer(text="Completed!")
|
||||
|
||||
if self.members:
|
||||
classifications = {
|
||||
"Attended": [],
|
||||
"Missing": []
|
||||
}
|
||||
for memid, mem in sorted(self.members.items(), key=lambda mem: mem[1].data.duration, reverse=True):
|
||||
mention = '<@{}>'.format(memid)
|
||||
if mem.has_attended:
|
||||
classifications["Attended"].append(
|
||||
"{} ({}%)".format(mention, (mem.data.duration * 100) // 3600)
|
||||
)
|
||||
else:
|
||||
classifications["Missing"].append(mention)
|
||||
|
||||
all_attended = all(mem.has_attended for mem in self.members.values())
|
||||
|
||||
bonus_line = (
|
||||
"{tick} Everyone attended, and received a `{bonus} LC` bonus!".format(
|
||||
tick=tick,
|
||||
bonus=GuildSettings(self.guild.id).accountability_bonus.value
|
||||
)
|
||||
if all_attended else
|
||||
"{cross} Some members missed the session, so everyone missed out on the bonus!".format(
|
||||
cross=cross
|
||||
)
|
||||
)
|
||||
if all_attended:
|
||||
embed.set_thumbnail(url=self.happy_lion)
|
||||
else:
|
||||
embed.set_thumbnail(url=self.sad_lion)
|
||||
|
||||
embed.description += "\n" + bonus_line
|
||||
for field, value in classifications.items():
|
||||
if value:
|
||||
embed.add_field(name=field, value='\n'.join(value))
|
||||
else:
|
||||
embed.description = "No members scheduled this session!"
|
||||
|
||||
return embed
|
||||
|
||||
def load(self, memberids: List[int] = None):
|
||||
"""
|
||||
Load data and update applicable caches.
|
||||
"""
|
||||
if not self.guild:
|
||||
return self
|
||||
|
||||
# Load setting data
|
||||
self.category = GuildSettings(self.guild.id).accountability_category.value
|
||||
self.lobby = GuildSettings(self.guild.id).accountability_lobby.value
|
||||
|
||||
if self.data:
|
||||
# Load channel
|
||||
if self.data.channelid:
|
||||
self.channel = self.guild.get_channel(self.data.channelid)
|
||||
|
||||
# Load message
|
||||
if self.data.messageid and self.lobby:
|
||||
self.message = discord.PartialMessage(
|
||||
channel=self.lobby,
|
||||
id=self.data.messageid
|
||||
)
|
||||
|
||||
# Load members
|
||||
if memberids:
|
||||
self.members = {
|
||||
memberid: SlotMember(self.data.slotid, memberid, self.guild)
|
||||
for memberid in memberids
|
||||
}
|
||||
|
||||
return self
|
||||
|
||||
async def _reload_members(self, memberids=None):
|
||||
"""
|
||||
Reload the timeslot members from the provided list, or data.
|
||||
Also updates the channel overwrites if required.
|
||||
To be used before the session has started.
|
||||
"""
|
||||
if self.data:
|
||||
if memberids is None:
|
||||
member_rows = accountability_members.fetch_rows_where(slotid=self.data.slotid)
|
||||
memberids = [row.userid for row in member_rows]
|
||||
|
||||
self.members = members = {
|
||||
memberid: SlotMember(self.data.slotid, memberid, self.guild)
|
||||
for memberid in memberids
|
||||
}
|
||||
|
||||
if self.channel:
|
||||
# Check and potentially update overwrites
|
||||
current_overwrites = self.channel.overwrites
|
||||
overwrites = {
|
||||
mem.member: self._member_overwrite
|
||||
for mem in members.values()
|
||||
if mem.member
|
||||
}
|
||||
overwrites[self.guild.default_role] = self._everyone_overwrite
|
||||
if current_overwrites != overwrites:
|
||||
await self.channel.edit(overwrites=overwrites)
|
||||
|
||||
def _refresh(self):
|
||||
"""
|
||||
Refresh the stored data row and reload.
|
||||
"""
|
||||
rows = accountability_rooms.fetch_rows_where(
|
||||
guildid=self.guild.id,
|
||||
start_at=self.start_time
|
||||
)
|
||||
self.data = rows[0] if rows else None
|
||||
|
||||
memberids = []
|
||||
if self.data:
|
||||
member_rows = accountability_members.fetch_rows_where(
|
||||
slotid=self.data.slotid
|
||||
)
|
||||
memberids = [row.userid for row in member_rows]
|
||||
self.load(memberids=memberids)
|
||||
|
||||
async def open(self):
|
||||
"""
|
||||
Open the accountability room.
|
||||
Creates a new voice channel, and sends the status message.
|
||||
Event logs any issues.
|
||||
Adds the TimeSlot to cache.
|
||||
Returns the (channelid, messageid).
|
||||
"""
|
||||
# Cleanup any non-existent members
|
||||
for memid, mem in list(self.members.items()):
|
||||
if not mem.data or not mem.member:
|
||||
self.members.pop(memid)
|
||||
|
||||
# Calculate overwrites
|
||||
overwrites = {
|
||||
mem.member: self._member_overwrite
|
||||
for mem in self.members.values()
|
||||
}
|
||||
overwrites[self.guild.default_role] = self._everyone_overwrite
|
||||
|
||||
# Create the channel. Log and bail if something went wrong.
|
||||
if self.data and not self.channel:
|
||||
try:
|
||||
self.channel = await self.guild.create_voice_channel(
|
||||
"Upcoming Scheduled Session",
|
||||
overwrites=overwrites,
|
||||
category=self.category
|
||||
)
|
||||
except discord.HTTPException:
|
||||
GuildSettings(self.guild.id).event_log.log(
|
||||
"Failed to create the scheduled session voice channel. Skipping this session.",
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
return None
|
||||
elif not self.data:
|
||||
self.channel = None
|
||||
|
||||
# Send the inital status message. Log and bail if something goes wrong.
|
||||
if not self.message:
|
||||
try:
|
||||
self.message = await self.lobby.send(
|
||||
embed=self.open_embed
|
||||
)
|
||||
except discord.HTTPException:
|
||||
GuildSettings(self.guild.id).event_log.log(
|
||||
"Failed to post the status message in the scheduled session lobby {}.\n"
|
||||
"Skipping this session.".format(self.lobby.mention),
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
return None
|
||||
if self.members:
|
||||
await self.channel_notify()
|
||||
return (self.channel.id if self.channel else None, self.message.id)
|
||||
|
||||
async def channel_notify(self, content=None):
|
||||
"""
|
||||
Ghost pings the session members in the lobby channel.
|
||||
"""
|
||||
if self.members:
|
||||
content = content or "Your scheduled session has started! Please join!"
|
||||
out = "{}\n\n{}".format(
|
||||
content,
|
||||
' '.join('<@{}>'.format(memid) for memid, mem in self.members.items() if not mem.has_attended)
|
||||
)
|
||||
out_msg = await self.lobby.send(out)
|
||||
await out_msg.delete()
|
||||
|
||||
async def start(self):
|
||||
"""
|
||||
Start the accountability room slot.
|
||||
Update the status message, and launch the DM reminder.
|
||||
"""
|
||||
dither = 15 * random.random()
|
||||
await asyncio.sleep(dither)
|
||||
if self.channel:
|
||||
try:
|
||||
await self.channel.edit(name="Scheduled Session Room")
|
||||
await self.channel.set_permissions(self.guild.default_role, view_channel=True, connect=False)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
asyncio.create_task(self.dm_reminder(delay=60))
|
||||
try:
|
||||
await self.message.edit(embed=self.status_embed)
|
||||
except discord.NotFound:
|
||||
try:
|
||||
self.message = await self.lobby.send(
|
||||
embed=self.status_embed
|
||||
)
|
||||
except discord.HTTPException:
|
||||
self.message = None
|
||||
|
||||
async def dm_reminder(self, delay=60):
|
||||
"""
|
||||
Notifies missing members with a direct message after 1 minute.
|
||||
"""
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="The scheduled session you booked has started!",
|
||||
description="Please join {}.".format(self.channel.mention),
|
||||
colour=discord.Colour.orange()
|
||||
).set_footer(
|
||||
text=self.guild.name,
|
||||
icon_url=self.guild.icon_url
|
||||
)
|
||||
|
||||
members = (mem.member for mem in self.members.values() if not mem.has_attended)
|
||||
members = (member for member in members if member)
|
||||
await asyncio.gather(
|
||||
*(member.send(embed=embed) for member in members),
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Delete the channel and update the status message to display a session summary.
|
||||
Unloads the TimeSlot from cache.
|
||||
"""
|
||||
dither = 15 * random.random()
|
||||
await asyncio.sleep(dither)
|
||||
if self.channel:
|
||||
try:
|
||||
await self.channel.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
if self.message:
|
||||
try:
|
||||
await self.message.edit(embed=self.summary_embed)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
# Reward members appropriately
|
||||
if self.guild:
|
||||
guild_settings = GuildSettings(self.guild.id)
|
||||
reward = guild_settings.accountability_reward.value
|
||||
if all(mem.has_attended for mem in self.members.values()):
|
||||
reward += guild_settings.accountability_bonus.value
|
||||
|
||||
for memid in self.members:
|
||||
Lion.fetch(self.guild.id, memid).addCoins(reward, bonus=True)
|
||||
|
||||
async def cancel(self):
|
||||
"""
|
||||
Cancel the slot, generally due to missing data.
|
||||
Updates the message and channel if possible, removes slot from cache, and also updates data.
|
||||
# TODO: Refund members
|
||||
"""
|
||||
if self.data:
|
||||
self.data.closed_at = utc_now()
|
||||
|
||||
if self.channel:
|
||||
try:
|
||||
await self.channel.delete()
|
||||
self.channel = None
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
if self.message:
|
||||
try:
|
||||
timestamp = int(self.start_time.timestamp())
|
||||
embed = discord.Embed(
|
||||
title="Session <t:{}:t> - <t:{}:t>".format(
|
||||
timestamp, timestamp + 3600
|
||||
),
|
||||
description="Session canceled!",
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
await self.message.edit(embed=embed)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
async def update_status(self):
|
||||
"""
|
||||
Intelligently update the status message.
|
||||
"""
|
||||
if self.message:
|
||||
if utc_now() < self.start_time:
|
||||
await self.message.edit(embed=self.open_embed)
|
||||
elif utc_now() < self.start_time + datetime.timedelta(hours=1):
|
||||
await self.message.edit(embed=self.status_embed)
|
||||
else:
|
||||
await self.message.edit(embed=self.summary_embed)
|
||||
6
src/modules/pending-rewrite/accountability/__init__.py
Normal file
6
src/modules/pending-rewrite/accountability/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .module import module
|
||||
|
||||
from . import data
|
||||
from . import admin
|
||||
from . import commands
|
||||
from . import tracker
|
||||
140
src/modules/pending-rewrite/accountability/admin.py
Normal file
140
src/modules/pending-rewrite/accountability/admin.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import asyncio
|
||||
import discord
|
||||
|
||||
import settings
|
||||
from settings import GuildSettings, GuildSetting
|
||||
|
||||
from .tracker import AccountabilityGuild as AG
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class accountability_category(settings.Channel, settings.GuildSetting):
|
||||
category = "Scheduled Sessions"
|
||||
|
||||
attr_name = "accountability_category"
|
||||
_data_column = "accountability_category"
|
||||
|
||||
display_name = "session_category"
|
||||
desc = "Category in which to make the scheduled session rooms."
|
||||
|
||||
_default = None
|
||||
|
||||
long_desc = (
|
||||
"\"Schedule session\" category channel.\n"
|
||||
"Scheduled sessions will be held in voice channels created under this category."
|
||||
)
|
||||
_accepts = "A category channel."
|
||||
|
||||
_chan_type = discord.ChannelType.category
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
# TODO Move this somewhere better
|
||||
if self.id not in AG.cache:
|
||||
AG(self.id)
|
||||
return "The session category has been changed to **{}**.".format(self.value.name)
|
||||
else:
|
||||
return "The scheduled session system has been started in **{}**.".format(self.value.name)
|
||||
else:
|
||||
if self.id in AG.cache:
|
||||
aguild = AG.cache.pop(self.id)
|
||||
if aguild.current_slot:
|
||||
asyncio.create_task(aguild.current_slot.cancel())
|
||||
if aguild.upcoming_slot:
|
||||
asyncio.create_task(aguild.upcoming_slot.cancel())
|
||||
return "The scheduled session system has been shut down."
|
||||
else:
|
||||
return "The scheduled session category has been unset."
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class accountability_lobby(settings.Channel, settings.GuildSetting):
|
||||
category = "Scheduled Sessions"
|
||||
|
||||
attr_name = "accountability_lobby"
|
||||
_data_column = attr_name
|
||||
|
||||
display_name = "session_lobby"
|
||||
desc = "Category in which to post scheduled session notifications updates."
|
||||
|
||||
_default = None
|
||||
|
||||
long_desc = (
|
||||
"Scheduled session updates will be posted here, and members will be notified in this channel.\n"
|
||||
"The channel will be automatically created in the configured `session_category` if it does not exist.\n"
|
||||
"Members do not need to be able to write in the channel."
|
||||
)
|
||||
_accepts = "Any text channel."
|
||||
|
||||
_chan_type = discord.ChannelType.text
|
||||
|
||||
async def auto_create(self):
|
||||
# TODO: FUTURE
|
||||
...
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class accountability_price(settings.Integer, GuildSetting):
|
||||
category = "Scheduled Sessions"
|
||||
|
||||
attr_name = "accountability_price"
|
||||
_data_column = attr_name
|
||||
|
||||
display_name = "session_price"
|
||||
desc = "Cost of booking a scheduled session."
|
||||
|
||||
_default = 500
|
||||
|
||||
long_desc = (
|
||||
"The price of booking each one hour scheduled session slot."
|
||||
)
|
||||
_accepts = "An integer number of coins."
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
return "Scheduled session slots now cost `{}` coins.".format(self.value)
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class accountability_bonus(settings.Integer, GuildSetting):
|
||||
category = "Scheduled Sessions"
|
||||
|
||||
attr_name = "accountability_bonus"
|
||||
_data_column = attr_name
|
||||
|
||||
display_name = "session_bonus"
|
||||
desc = "Bonus given when everyone attends a scheduled session slot."
|
||||
|
||||
_default = 750
|
||||
|
||||
long_desc = (
|
||||
"The extra bonus given to each scheduled session member when everyone who booked attended the session."
|
||||
)
|
||||
_accepts = "An integer number of coins."
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
return "Scheduled session members will now get `{}` coins if everyone joins.".format(self.value)
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class accountability_reward(settings.Integer, GuildSetting):
|
||||
category = "Scheduled Sessions"
|
||||
|
||||
attr_name = "accountability_reward"
|
||||
_data_column = attr_name
|
||||
|
||||
display_name = "session_reward"
|
||||
desc = "The individual reward given when a member attends their booked scheduled session."
|
||||
|
||||
_default = 500
|
||||
|
||||
long_desc = (
|
||||
"Reward given to a member who attends a booked scheduled session."
|
||||
)
|
||||
_accepts = "An integer number of coins."
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
return "Members will now get `{}` coins when they attend their scheduled session.".format(self.value)
|
||||
596
src/modules/pending-rewrite/accountability/commands.py
Normal file
596
src/modules/pending-rewrite/accountability/commands.py
Normal file
@@ -0,0 +1,596 @@
|
||||
import re
|
||||
import datetime
|
||||
import discord
|
||||
import asyncio
|
||||
import contextlib
|
||||
from cmdClient.checks import in_guild
|
||||
|
||||
from meta import client
|
||||
from utils.lib import multiselect_regex, parse_ranges, prop_tabulate
|
||||
from data import NOTNULL
|
||||
from data.conditions import GEQ, LEQ
|
||||
|
||||
from .module import module
|
||||
from .lib import utc_now
|
||||
from .tracker import AccountabilityGuild as AGuild
|
||||
from .tracker import room_lock
|
||||
from .TimeSlot import SlotMember
|
||||
from .data import accountability_members, accountability_member_info, accountability_rooms
|
||||
|
||||
|
||||
hint_icon = "https://projects.iamcal.com/emoji-data/img-apple-64/1f4a1.png"
|
||||
|
||||
|
||||
def time_format(time):
|
||||
diff = (time - utc_now()).total_seconds()
|
||||
if diff < 0:
|
||||
diffstr = "`Right Now!!`"
|
||||
elif diff < 600:
|
||||
diffstr = "`Very soon!!`"
|
||||
elif diff < 3600:
|
||||
diffstr = "`In <1 hour `"
|
||||
else:
|
||||
hours = round(diff / 3600)
|
||||
diffstr = "`In {:>2} hour{}`".format(hours, 's' if hours > 1 else ' ')
|
||||
|
||||
return "{} | <t:{:.0f}:t> - <t:{:.0f}:t>".format(
|
||||
diffstr,
|
||||
time.timestamp(),
|
||||
time.timestamp() + 3600,
|
||||
)
|
||||
|
||||
|
||||
user_locks = {} # Map userid -> ctx
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ensure_exclusive(ctx):
|
||||
"""
|
||||
Cancel any existing exclusive contexts for the author.
|
||||
"""
|
||||
old_ctx = user_locks.pop(ctx.author.id, None)
|
||||
if old_ctx:
|
||||
[task.cancel() for task in old_ctx.tasks]
|
||||
|
||||
user_locks[ctx.author.id] = ctx
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
new_ctx = user_locks.get(ctx.author.id, None)
|
||||
if new_ctx and new_ctx.msg.id == ctx.msg.id:
|
||||
user_locks.pop(ctx.author.id)
|
||||
|
||||
|
||||
@module.cmd(
|
||||
name="schedule",
|
||||
desc="View your schedule, and get rewarded for attending scheduled sessions!",
|
||||
group="Productivity",
|
||||
aliases=('rooms', 'sessions')
|
||||
)
|
||||
@in_guild()
|
||||
async def cmd_rooms(ctx):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}schedule
|
||||
{prefix}schedule book
|
||||
{prefix}schedule cancel
|
||||
Description:
|
||||
View your schedule with `{prefix}schedule`.
|
||||
Use `{prefix}schedule book` to schedule a session at a selected time..
|
||||
Use `{prefix}schedule cancel` to cancel a scheduled session.
|
||||
"""
|
||||
lower = ctx.args.lower()
|
||||
splits = lower.split()
|
||||
command = splits[0] if splits else None
|
||||
|
||||
if not ctx.guild_settings.accountability_category.value:
|
||||
return await ctx.error_reply("The scheduled session system isn't set up!")
|
||||
|
||||
# First grab the sessions the member is booked in
|
||||
joined_rows = accountability_member_info.select_where(
|
||||
userid=ctx.author.id,
|
||||
start_at=GEQ(utc_now()),
|
||||
_extra="ORDER BY start_at ASC"
|
||||
)
|
||||
|
||||
if command == 'cancel':
|
||||
if not joined_rows:
|
||||
return await ctx.error_reply("You have no scheduled sessions to cancel!")
|
||||
|
||||
# Show unbooking menu
|
||||
lines = [
|
||||
"`[{:>2}]` | {}".format(i, time_format(row['start_at']))
|
||||
for i, row in enumerate(joined_rows)
|
||||
]
|
||||
out_msg = await ctx.reply(
|
||||
content="Please reply with the number(s) of the sessions you want to cancel. E.g. `1, 3, 5` or `1-3, 7-8`.",
|
||||
embed=discord.Embed(
|
||||
title="Please choose the sessions you want to cancel.",
|
||||
description='\n'.join(lines),
|
||||
colour=discord.Colour.orange()
|
||||
).set_footer(
|
||||
text=(
|
||||
"All times are in your own timezone! Hover over a time to see the date."
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await ctx.cancellable(
|
||||
out_msg,
|
||||
cancel_message="Cancel menu closed, no scheduled sessions were cancelled.",
|
||||
timeout=70
|
||||
)
|
||||
|
||||
def check(msg):
|
||||
valid = msg.channel == ctx.ch and msg.author == ctx.author
|
||||
valid = valid and (re.search(multiselect_regex, msg.content) or msg.content.lower() == 'c')
|
||||
return valid
|
||||
|
||||
with ensure_exclusive(ctx):
|
||||
try:
|
||||
message = await ctx.client.wait_for('message', check=check, timeout=60)
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
await out_msg.edit(
|
||||
content=None,
|
||||
embed=discord.Embed(
|
||||
description="Cancel menu timed out, no scheduled sessions were cancelled.",
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
)
|
||||
await out_msg.clear_reactions()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
return
|
||||
|
||||
try:
|
||||
await out_msg.delete()
|
||||
await message.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
if message.content.lower() == 'c':
|
||||
return
|
||||
|
||||
to_cancel = [
|
||||
joined_rows[index]
|
||||
for index in parse_ranges(message.content) if index < len(joined_rows)
|
||||
]
|
||||
if not to_cancel:
|
||||
return await ctx.error_reply("No valid sessions selected for cancellation.")
|
||||
elif any(row['start_at'] < utc_now() for row in to_cancel):
|
||||
return await ctx.error_reply("You can't cancel a running session!")
|
||||
|
||||
slotids = [row['slotid'] for row in to_cancel]
|
||||
async with room_lock:
|
||||
deleted = accountability_members.delete_where(
|
||||
userid=ctx.author.id,
|
||||
slotid=slotids
|
||||
)
|
||||
|
||||
# Handle case where the slot has already opened
|
||||
# TODO: Possible race condition if they open over the hour border? Might never cancel
|
||||
for row in to_cancel:
|
||||
aguild = AGuild.cache.get(row['guildid'], None)
|
||||
if aguild and aguild.upcoming_slot and aguild.upcoming_slot.data:
|
||||
if aguild.upcoming_slot.data.slotid in slotids:
|
||||
aguild.upcoming_slot.members.pop(ctx.author.id, None)
|
||||
if aguild.upcoming_slot.channel:
|
||||
try:
|
||||
await aguild.upcoming_slot.channel.set_permissions(
|
||||
ctx.author,
|
||||
overwrite=None
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
await aguild.upcoming_slot.update_status()
|
||||
break
|
||||
|
||||
ctx.alion.addCoins(sum(row[2] for row in deleted))
|
||||
|
||||
remaining = [row for row in joined_rows if row['slotid'] not in slotids]
|
||||
if not remaining:
|
||||
await ctx.embed_reply("Cancelled all your upcoming scheduled sessions!")
|
||||
else:
|
||||
next_booked_time = min(row['start_at'] for row in remaining)
|
||||
if len(to_cancel) > 1:
|
||||
await ctx.embed_reply(
|
||||
"Cancelled `{}` upcoming sessions!\nYour next session is at <t:{:.0f}>.".format(
|
||||
len(to_cancel),
|
||||
next_booked_time.timestamp()
|
||||
)
|
||||
)
|
||||
else:
|
||||
await ctx.embed_reply(
|
||||
"Cancelled your session at <t:{:.0f}>!\n"
|
||||
"Your next session is at <t:{:.0f}>.".format(
|
||||
to_cancel[0]['start_at'].timestamp(),
|
||||
next_booked_time.timestamp()
|
||||
)
|
||||
)
|
||||
elif command == 'book':
|
||||
# Show booking menu
|
||||
# Get attendee count
|
||||
rows = accountability_member_info.select_where(
|
||||
guildid=ctx.guild.id,
|
||||
userid=NOTNULL,
|
||||
select_columns=(
|
||||
'slotid',
|
||||
'start_at',
|
||||
'COUNT(*) as num'
|
||||
),
|
||||
_extra="GROUP BY start_at, slotid"
|
||||
)
|
||||
attendees = {row['start_at']: row['num'] for row in rows}
|
||||
attendee_pad = max((len(str(num)) for num in attendees.values()), default=1)
|
||||
|
||||
# Build lines
|
||||
already_joined_times = set(row['start_at'] for row in joined_rows)
|
||||
start_time = utc_now().replace(minute=0, second=0, microsecond=0)
|
||||
times = (
|
||||
start_time + datetime.timedelta(hours=n)
|
||||
for n in range(1, 25)
|
||||
)
|
||||
times = [
|
||||
time for time in times
|
||||
if time not in already_joined_times and (time - utc_now()).total_seconds() > 660
|
||||
]
|
||||
lines = [
|
||||
"`[{num:>2}]` | `{count:>{count_pad}}` attending | {time}".format(
|
||||
num=i,
|
||||
count=attendees.get(time, 0), count_pad=attendee_pad,
|
||||
time=time_format(time),
|
||||
)
|
||||
for i, time in enumerate(times)
|
||||
]
|
||||
# TODO: Nicer embed
|
||||
# TODO: Don't allow multi bookings if the member has a bad attendance rate
|
||||
out_msg = await ctx.reply(
|
||||
content=(
|
||||
"Please reply with the number(s) of the sessions you want to book. E.g. `1, 3, 5` or `1-3, 7-8`."
|
||||
),
|
||||
embed=discord.Embed(
|
||||
title="Please choose the sessions you want to schedule.",
|
||||
description='\n'.join(lines),
|
||||
colour=discord.Colour.orange()
|
||||
).set_footer(
|
||||
text=(
|
||||
"All times are in your own timezone! Hover over a time to see the date."
|
||||
)
|
||||
)
|
||||
)
|
||||
await ctx.cancellable(
|
||||
out_msg,
|
||||
cancel_message="Booking menu cancelled, no sessions were booked.",
|
||||
timeout=60
|
||||
)
|
||||
|
||||
def check(msg):
|
||||
valid = msg.channel == ctx.ch and msg.author == ctx.author
|
||||
valid = valid and (re.search(multiselect_regex, msg.content) or msg.content.lower() == 'c')
|
||||
return valid
|
||||
|
||||
with ensure_exclusive(ctx):
|
||||
try:
|
||||
message = await ctx.client.wait_for('message', check=check, timeout=30)
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
await out_msg.edit(
|
||||
content=None,
|
||||
embed=discord.Embed(
|
||||
description="Booking menu timed out, no sessions were booked.",
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
)
|
||||
await out_msg.clear_reactions()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
return
|
||||
|
||||
try:
|
||||
await out_msg.delete()
|
||||
await message.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
if message.content.lower() == 'c':
|
||||
return
|
||||
|
||||
to_book = [
|
||||
times[index]
|
||||
for index in parse_ranges(message.content) if index < len(times)
|
||||
]
|
||||
if not to_book:
|
||||
return await ctx.error_reply("No valid sessions selected.")
|
||||
elif any(time < utc_now() for time in to_book):
|
||||
return await ctx.error_reply("You can't book a running session!")
|
||||
cost = len(to_book) * ctx.guild_settings.accountability_price.value
|
||||
if cost > ctx.alion.coins:
|
||||
return await ctx.error_reply(
|
||||
"Sorry, booking `{}` sessions costs `{}` coins, and you only have `{}`!".format(
|
||||
len(to_book),
|
||||
cost,
|
||||
ctx.alion.coins
|
||||
)
|
||||
)
|
||||
|
||||
# Add the member to data, creating the row if required
|
||||
slot_rows = accountability_rooms.fetch_rows_where(
|
||||
guildid=ctx.guild.id,
|
||||
start_at=to_book
|
||||
)
|
||||
slotids = [row.slotid for row in slot_rows]
|
||||
to_add = set(to_book).difference((row.start_at for row in slot_rows))
|
||||
if to_add:
|
||||
slotids.extend(row['slotid'] for row in accountability_rooms.insert_many(
|
||||
*((ctx.guild.id, start_at) for start_at in to_add),
|
||||
insert_keys=('guildid', 'start_at'),
|
||||
))
|
||||
accountability_members.insert_many(
|
||||
*((slotid, ctx.author.id, ctx.guild_settings.accountability_price.value) for slotid in slotids),
|
||||
insert_keys=('slotid', 'userid', 'paid')
|
||||
)
|
||||
|
||||
# Handle case where the slot has already opened
|
||||
# TODO: Fix this, doesn't always work
|
||||
aguild = AGuild.cache.get(ctx.guild.id, None)
|
||||
if aguild:
|
||||
if aguild.upcoming_slot and aguild.upcoming_slot.start_time in to_book:
|
||||
slot = aguild.upcoming_slot
|
||||
if not slot.data:
|
||||
# Handle slot activation
|
||||
slot._refresh()
|
||||
channelid, messageid = await slot.open()
|
||||
accountability_rooms.update_where(
|
||||
{'channelid': channelid, 'messageid': messageid},
|
||||
slotid=slot.data.slotid
|
||||
)
|
||||
else:
|
||||
slot.members[ctx.author.id] = SlotMember(slot.data.slotid, ctx.author.id, ctx.guild)
|
||||
# Also update the channel permissions
|
||||
try:
|
||||
await slot.channel.set_permissions(ctx.author, view_channel=True, connect=True)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
await slot.update_status()
|
||||
ctx.alion.addCoins(-cost)
|
||||
|
||||
# Ack purchase
|
||||
embed = discord.Embed(
|
||||
title="You have scheduled the following session{}!".format('s' if len(to_book) > 1 else ''),
|
||||
description=(
|
||||
"*Please attend all your scheduled sessions!*\n"
|
||||
"*If you can't attend, cancel with* `{}schedule cancel`\n\n{}"
|
||||
).format(
|
||||
ctx.best_prefix,
|
||||
'\n'.join(time_format(time) for time in to_book),
|
||||
),
|
||||
colour=discord.Colour.orange()
|
||||
).set_footer(
|
||||
text=(
|
||||
"Use {prefix}schedule to see your current schedule.\n"
|
||||
).format(prefix=ctx.best_prefix)
|
||||
)
|
||||
try:
|
||||
await ctx.reply(
|
||||
embed=embed,
|
||||
reference=ctx.msg
|
||||
)
|
||||
except discord.NotFound:
|
||||
await ctx.reply(embed=embed)
|
||||
else:
|
||||
# Show accountability room information for this user
|
||||
# Accountability profile
|
||||
# Author
|
||||
# Special case for no past bookings, emphasis hint
|
||||
# Hint on Bookings section for booking/cancelling as applicable
|
||||
# Description has stats
|
||||
# Footer says that all times are in their timezone
|
||||
# TODO: attendance requirement shouldn't be retroactive! Add attended data column
|
||||
# Attended `{}` out of `{}` booked (`{}%` attendance rate!)
|
||||
# Attendance streak: `{}` days attended with no missed sessions!
|
||||
# Add explanation for first time users
|
||||
|
||||
# Get all slots the member has ever booked
|
||||
history = accountability_member_info.select_where(
|
||||
userid=ctx.author.id,
|
||||
# start_at=LEQ(utc_now() - datetime.timedelta(hours=1)),
|
||||
start_at=LEQ(utc_now()),
|
||||
select_columns=("*", "(duration > 0 OR last_joined_at IS NOT NULL) AS attended"),
|
||||
_extra="ORDER BY start_at DESC"
|
||||
)
|
||||
|
||||
if not (history or joined_rows):
|
||||
# First-timer information
|
||||
about = (
|
||||
"You haven't scheduled any study sessions yet!\n"
|
||||
"Schedule a session by typing **`{}schedule book`** and selecting "
|
||||
"the hours you intend to study, "
|
||||
"then attend by joining the session voice channel when it starts!\n"
|
||||
"Only if everyone attends will they get the bonus of `{}` LionCoins!\n"
|
||||
"Let's all do our best and keep each other accountable 🔥"
|
||||
).format(
|
||||
ctx.best_prefix,
|
||||
ctx.guild_settings.accountability_bonus.value
|
||||
)
|
||||
embed = discord.Embed(
|
||||
description=about,
|
||||
colour=discord.Colour.orange()
|
||||
)
|
||||
embed.set_footer(
|
||||
text="Please keep your DMs open so I can notify you when the session starts!\n",
|
||||
icon_url=hint_icon
|
||||
)
|
||||
await ctx.reply(embed=embed)
|
||||
else:
|
||||
# Build description with stats
|
||||
if history:
|
||||
# First get the counts
|
||||
attended_count = sum(row['attended'] for row in history)
|
||||
total_count = len(history)
|
||||
total_duration = sum(row['duration'] for row in history)
|
||||
|
||||
# Add current session to duration if it exists
|
||||
if history[0]['last_joined_at'] and (utc_now() - history[0]['start_at']).total_seconds() < 3600:
|
||||
total_duration += int((utc_now() - history[0]['last_joined_at']).total_seconds())
|
||||
|
||||
# Calculate the streak
|
||||
timezone = ctx.alion.settings.timezone.value
|
||||
|
||||
streak = 0
|
||||
current_streak = None
|
||||
max_streak = 0
|
||||
day_attended = None
|
||||
date = utc_now().astimezone(timezone).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
daydiff = datetime.timedelta(days=1)
|
||||
|
||||
i = 0
|
||||
while i < len(history):
|
||||
row = history[i]
|
||||
i += 1
|
||||
if not row['attended']:
|
||||
# Not attended, streak broken
|
||||
pass
|
||||
elif row['start_at'] > date:
|
||||
# They attended this day
|
||||
day_attended = True
|
||||
continue
|
||||
elif day_attended is None:
|
||||
# Didn't attend today, but don't break streak
|
||||
day_attended = False
|
||||
date -= daydiff
|
||||
i -= 1
|
||||
continue
|
||||
elif not day_attended:
|
||||
# Didn't attend the day, streak broken
|
||||
date -= daydiff
|
||||
i -= 1
|
||||
pass
|
||||
else:
|
||||
# Attended the day
|
||||
streak += 1
|
||||
|
||||
# Move window to the previous day and try the row again
|
||||
date -= daydiff
|
||||
day_attended = False
|
||||
i -= 1
|
||||
continue
|
||||
|
||||
max_streak = max(max_streak, streak)
|
||||
if current_streak is None:
|
||||
current_streak = streak
|
||||
streak = 0
|
||||
|
||||
# Handle loop exit state, i.e. the last streak
|
||||
if day_attended:
|
||||
streak += 1
|
||||
max_streak = max(max_streak, streak)
|
||||
if current_streak is None:
|
||||
current_streak = streak
|
||||
|
||||
# Build the stats
|
||||
table = {
|
||||
"Sessions": "**{}** attended out of **{}**, `{:.0f}%` attendance rate.".format(
|
||||
attended_count,
|
||||
total_count,
|
||||
(attended_count * 100) / total_count,
|
||||
),
|
||||
"Time": "**{:02}:{:02}** in scheduled sessions.".format(
|
||||
total_duration // 3600,
|
||||
(total_duration % 3600) // 60
|
||||
),
|
||||
"Streak": "**{}** day{} with no missed sessions! (Longest: **{}** day{}.)".format(
|
||||
current_streak,
|
||||
's' if current_streak != 1 else '',
|
||||
max_streak,
|
||||
's' if max_streak != 1 else '',
|
||||
),
|
||||
}
|
||||
desc = prop_tabulate(*zip(*table.items()))
|
||||
else:
|
||||
desc = (
|
||||
"Good luck with your next session!\n"
|
||||
)
|
||||
|
||||
# Build currently booked list
|
||||
|
||||
if joined_rows:
|
||||
# TODO: (Future) calendar link
|
||||
# Get attendee counts for currently booked sessions
|
||||
rows = accountability_member_info.select_where(
|
||||
slotid=[row["slotid"] for row in joined_rows],
|
||||
userid=NOTNULL,
|
||||
select_columns=(
|
||||
'slotid',
|
||||
'guildid',
|
||||
'start_at',
|
||||
'COUNT(*) as num'
|
||||
),
|
||||
_extra="GROUP BY start_at, slotid, guildid ORDER BY start_at ASC"
|
||||
)
|
||||
attendees = {
|
||||
row['start_at']: (row['num'], row['guildid']) for row in rows
|
||||
}
|
||||
attendee_pad = max((len(str(num)) for num, _ in attendees.values()), default=1)
|
||||
|
||||
# TODO: Allow cancel to accept multiselect keys as args
|
||||
show_guild = any(guildid != ctx.guild.id for _, guildid in attendees.values())
|
||||
guild_map = {}
|
||||
if show_guild:
|
||||
for _, guildid in attendees.values():
|
||||
if guildid not in guild_map:
|
||||
guild = ctx.client.get_guild(guildid)
|
||||
if not guild:
|
||||
try:
|
||||
guild = await ctx.client.fetch_guild(guildid)
|
||||
except discord.HTTPException:
|
||||
guild = None
|
||||
guild_map[guildid] = guild
|
||||
|
||||
booked_list = '\n'.join(
|
||||
"`{:>{}}` attendees | {} {}".format(
|
||||
num,
|
||||
attendee_pad,
|
||||
time_format(start),
|
||||
"" if not show_guild else (
|
||||
"on this server" if guildid == ctx.guild.id else "in **{}**".format(
|
||||
guild_map[guildid] or "Unknown"
|
||||
)
|
||||
)
|
||||
) for start, (num, guildid) in attendees.items()
|
||||
)
|
||||
booked_field = (
|
||||
"{}\n\n"
|
||||
"*If you can't make your session, please cancel using `{}schedule cancel`!*"
|
||||
).format(booked_list, ctx.best_prefix)
|
||||
|
||||
# Temporary footer for acclimatisation
|
||||
# footer = "All times are displayed in your own timezone!"
|
||||
footer = "Book another session using {}schedule book".format(ctx.best_prefix)
|
||||
else:
|
||||
booked_field = (
|
||||
"Your schedule is empty!\n"
|
||||
"Book another session using `{}schedule book`."
|
||||
).format(ctx.best_prefix)
|
||||
footer = "Please keep your DMs open for notifications!"
|
||||
|
||||
# Finally, build embed
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
description=desc,
|
||||
).set_author(
|
||||
name="Schedule statistics for {}".format(ctx.author.name),
|
||||
icon_url=ctx.author.avatar_url
|
||||
).set_footer(
|
||||
text=footer,
|
||||
icon_url=hint_icon
|
||||
).add_field(
|
||||
name="Upcoming sessions",
|
||||
value=booked_field
|
||||
)
|
||||
|
||||
# And send it!
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
|
||||
# TODO: roomadmin
|
||||
34
src/modules/pending-rewrite/accountability/data.py
Normal file
34
src/modules/pending-rewrite/accountability/data.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from data import Table, RowTable
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
|
||||
accountability_rooms = RowTable(
|
||||
'accountability_slots',
|
||||
('slotid', 'channelid', 'guildid', 'start_at', 'messageid', 'closed_at'),
|
||||
'slotid',
|
||||
cache=TTLCache(5000, ttl=60*70),
|
||||
attach_as='accountability_rooms'
|
||||
)
|
||||
|
||||
|
||||
accountability_members = RowTable(
|
||||
'accountability_members',
|
||||
('slotid', 'userid', 'paid', 'duration', 'last_joined_at'),
|
||||
('slotid', 'userid'),
|
||||
cache=TTLCache(5000, ttl=60*70)
|
||||
)
|
||||
|
||||
accountability_member_info = Table('accountability_member_info')
|
||||
accountability_open_slots = Table('accountability_open_slots')
|
||||
|
||||
# @accountability_member_info.save_query
|
||||
# def user_streaks(userid, min_duration):
|
||||
# with accountability_member_info.conn as conn:
|
||||
# cursor = conn.cursor()
|
||||
# with cursor:
|
||||
# cursor.execute(
|
||||
# """
|
||||
|
||||
# """
|
||||
# )
|
||||
8
src/modules/pending-rewrite/accountability/lib.py
Normal file
8
src/modules/pending-rewrite/accountability/lib.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import datetime
|
||||
|
||||
|
||||
def utc_now():
|
||||
"""
|
||||
Return the current timezone-aware utc timestamp.
|
||||
"""
|
||||
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
|
||||
4
src/modules/pending-rewrite/accountability/module.py
Normal file
4
src/modules/pending-rewrite/accountability/module.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from LionModule import LionModule
|
||||
|
||||
|
||||
module = LionModule("Accountability")
|
||||
515
src/modules/pending-rewrite/accountability/tracker.py
Normal file
515
src/modules/pending-rewrite/accountability/tracker.py
Normal file
@@ -0,0 +1,515 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import collections
|
||||
import traceback
|
||||
import logging
|
||||
import discord
|
||||
from typing import Dict
|
||||
from discord.utils import sleep_until
|
||||
|
||||
from meta import client
|
||||
from utils.interactive import discord_shield
|
||||
from data import NULL, NOTNULL, tables
|
||||
from data.conditions import LEQ, THIS_SHARD
|
||||
from settings import GuildSettings
|
||||
|
||||
from .TimeSlot import TimeSlot
|
||||
from .lib import utc_now
|
||||
from .data import accountability_rooms, accountability_members
|
||||
from .module import module
|
||||
|
||||
|
||||
voice_ignore_lock = asyncio.Lock()
|
||||
room_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def locker(lock):
|
||||
"""
|
||||
Function decorator to wrap the function in a provided Lock
|
||||
"""
|
||||
def decorator(func):
|
||||
async def wrapped(*args, **kwargs):
|
||||
async with lock:
|
||||
return await func(*args, **kwargs)
|
||||
return wrapped
|
||||
return decorator
|
||||
|
||||
|
||||
class AccountabilityGuild:
|
||||
__slots__ = ('guildid', 'current_slot', 'upcoming_slot')
|
||||
|
||||
cache: Dict[int, 'AccountabilityGuild'] = {} # Map guildid -> AccountabilityGuild
|
||||
|
||||
def __init__(self, guildid):
|
||||
self.guildid = guildid
|
||||
self.current_slot = None
|
||||
self.upcoming_slot = None
|
||||
|
||||
self.cache[guildid] = self
|
||||
|
||||
@property
|
||||
def guild(self):
|
||||
return client.get_guild(self.guildid)
|
||||
|
||||
@property
|
||||
def guild_settings(self):
|
||||
return GuildSettings(self.guildid)
|
||||
|
||||
def advance(self):
|
||||
self.current_slot = self.upcoming_slot
|
||||
self.upcoming_slot = None
|
||||
|
||||
|
||||
async def open_next(start_time):
|
||||
"""
|
||||
Open all the upcoming accountability rooms, and fire channel notify.
|
||||
To be executed ~5 minutes to the hour.
|
||||
"""
|
||||
# Pre-fetch the new slot data, also populating the table caches
|
||||
room_data = accountability_rooms.fetch_rows_where(
|
||||
start_at=start_time,
|
||||
guildid=THIS_SHARD
|
||||
)
|
||||
guild_rows = {row.guildid: row for row in room_data}
|
||||
member_data = accountability_members.fetch_rows_where(
|
||||
slotid=[row.slotid for row in room_data]
|
||||
) if room_data else []
|
||||
slot_memberids = collections.defaultdict(list)
|
||||
for row in member_data:
|
||||
slot_memberids[row.slotid].append(row.userid)
|
||||
|
||||
# Open a new slot in each accountability guild
|
||||
to_update = [] # Cache of slot update data to be applied at the end
|
||||
for aguild in list(AccountabilityGuild.cache.values()):
|
||||
guild = aguild.guild
|
||||
if guild:
|
||||
# Initialise next TimeSlot
|
||||
slot = TimeSlot(
|
||||
guild,
|
||||
start_time,
|
||||
data=guild_rows.get(aguild.guildid, None)
|
||||
)
|
||||
slot.load(memberids=slot_memberids[slot.data.slotid] if slot.data else None)
|
||||
|
||||
if not slot.category:
|
||||
# Log and unload guild
|
||||
aguild.guild_settings.event_log.log(
|
||||
"The scheduled session category couldn't be found!\n"
|
||||
"Shutting down the scheduled session system in this server.\n"
|
||||
"To re-activate, please reconfigure `config session_category`."
|
||||
)
|
||||
AccountabilityGuild.cache.pop(aguild.guildid, None)
|
||||
await slot.cancel()
|
||||
continue
|
||||
elif not slot.lobby:
|
||||
# TODO: Consider putting in TimeSlot.open().. or even better in accountability_lobby.create()
|
||||
# Create a new lobby
|
||||
try:
|
||||
channel = await guild.create_text_channel(
|
||||
name="session-lobby",
|
||||
category=slot.category,
|
||||
reason="Automatic creation of scheduled session lobby."
|
||||
)
|
||||
aguild.guild_settings.accountability_lobby.value = channel
|
||||
slot.lobby = channel
|
||||
except discord.HTTPException:
|
||||
# Event log failure and skip session
|
||||
aguild.guild_settings.event_log.log(
|
||||
"Failed to create the scheduled session lobby text channel.\n"
|
||||
"Please set the lobby channel manually with `config`."
|
||||
)
|
||||
await slot.cancel()
|
||||
continue
|
||||
|
||||
# Event log creation
|
||||
aguild.guild_settings.event_log.log(
|
||||
"Automatically created a scheduled session lobby channel {}.".format(channel.mention)
|
||||
)
|
||||
|
||||
results = await slot.open()
|
||||
if results is None:
|
||||
# Couldn't open the channel for some reason.
|
||||
# Should already have been logged in `open`.
|
||||
# Skip this session
|
||||
await slot.cancel()
|
||||
continue
|
||||
elif slot.data:
|
||||
to_update.append((results[0], results[1], slot.data.slotid))
|
||||
|
||||
# Time slot should now be open and ready to start
|
||||
aguild.upcoming_slot = slot
|
||||
else:
|
||||
# Unload guild from cache
|
||||
AccountabilityGuild.cache.pop(aguild.guildid, None)
|
||||
|
||||
# Update slot data
|
||||
if to_update:
|
||||
accountability_rooms.update_many(
|
||||
*to_update,
|
||||
set_keys=('channelid', 'messageid'),
|
||||
where_keys=('slotid',)
|
||||
)
|
||||
|
||||
|
||||
async def turnover():
|
||||
"""
|
||||
Switchover from the current accountability rooms to the next ones.
|
||||
To be executed as close as possible to the hour.
|
||||
"""
|
||||
now = utc_now()
|
||||
|
||||
# Open event lock so we don't read voice channel movement
|
||||
async with voice_ignore_lock:
|
||||
# Update session data for completed sessions
|
||||
last_slots = [
|
||||
aguild.current_slot for aguild in AccountabilityGuild.cache.values()
|
||||
if aguild.current_slot is not None
|
||||
]
|
||||
|
||||
to_update = [
|
||||
(mem.data.duration + int((now - mem.data.last_joined_at).total_seconds()), None, mem.slotid, mem.userid)
|
||||
for slot in last_slots for mem in slot.members.values()
|
||||
if mem.data and mem.data.last_joined_at
|
||||
]
|
||||
if to_update:
|
||||
accountability_members.update_many(
|
||||
*to_update,
|
||||
set_keys=('duration', 'last_joined_at'),
|
||||
where_keys=('slotid', 'userid'),
|
||||
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
|
||||
)
|
||||
|
||||
# Close all completed rooms, update data
|
||||
await asyncio.gather(*(slot.close() for slot in last_slots), return_exceptions=True)
|
||||
update_slots = [slot.data.slotid for slot in last_slots if slot.data]
|
||||
if update_slots:
|
||||
accountability_rooms.update_where(
|
||||
{'closed_at': utc_now()},
|
||||
slotid=update_slots
|
||||
)
|
||||
|
||||
# Rotate guild sessions
|
||||
[aguild.advance() for aguild in AccountabilityGuild.cache.values()]
|
||||
|
||||
# TODO: (FUTURE) with high volume, we might want to start the sessions before moving the members.
|
||||
# We could break up the session starting?
|
||||
|
||||
# ---------- Start next session ----------
|
||||
current_slots = [
|
||||
aguild.current_slot for aguild in AccountabilityGuild.cache.values()
|
||||
if aguild.current_slot is not None
|
||||
]
|
||||
slotmap = {slot.data.slotid: slot for slot in current_slots if slot.data}
|
||||
|
||||
# Reload the slot members in case they cancelled from another shard
|
||||
member_data = accountability_members.fetch_rows_where(
|
||||
slotid=list(slotmap.keys())
|
||||
) if slotmap else []
|
||||
slot_memberids = {slotid: [] for slotid in slotmap}
|
||||
for row in member_data:
|
||||
slot_memberids[row.slotid].append(row.userid)
|
||||
reload_tasks = (
|
||||
slot._reload_members(memberids=slot_memberids[slotid])
|
||||
for slotid, slot in slotmap.items()
|
||||
)
|
||||
await asyncio.gather(
|
||||
*reload_tasks,
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
# Move members of the next session over to the session channel
|
||||
movement_tasks = (
|
||||
mem.member.edit(
|
||||
voice_channel=slot.channel,
|
||||
reason="Moving to scheduled session."
|
||||
)
|
||||
for slot in current_slots
|
||||
for mem in slot.members.values()
|
||||
if mem.data and mem.member and mem.member.voice and mem.member.voice.channel != slot.channel
|
||||
)
|
||||
# We return exceptions here to ignore any permission issues that occur with moving members.
|
||||
# It's also possible (likely) that members will move while we are moving other members
|
||||
# Returning the exceptions ensures that they are explicitly ignored
|
||||
await asyncio.gather(
|
||||
*movement_tasks,
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
# Update session data of all members in new channels
|
||||
member_session_data = [
|
||||
(0, slot.start_time, mem.slotid, mem.userid)
|
||||
for slot in current_slots
|
||||
for mem in slot.members.values()
|
||||
if mem.data and mem.member and mem.member.voice and mem.member.voice.channel == slot.channel
|
||||
]
|
||||
if member_session_data:
|
||||
accountability_members.update_many(
|
||||
*member_session_data,
|
||||
set_keys=('duration', 'last_joined_at'),
|
||||
where_keys=('slotid', 'userid'),
|
||||
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
|
||||
)
|
||||
|
||||
# Start all the current rooms
|
||||
await asyncio.gather(
|
||||
*(slot.start() for slot in current_slots),
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
|
||||
@client.add_after_event('voice_state_update')
|
||||
async def room_watchdog(client, member, before, after):
|
||||
"""
|
||||
Update session data when a member joins or leaves an accountability room.
|
||||
Ignores events that occur while `voice_ignore_lock` is held.
|
||||
"""
|
||||
if not voice_ignore_lock.locked() and before.channel != after.channel:
|
||||
aguild = AccountabilityGuild.cache.get(member.guild.id)
|
||||
if aguild and aguild.current_slot and aguild.current_slot.channel:
|
||||
slot = aguild.current_slot
|
||||
if member.id in slot.members:
|
||||
if after.channel and after.channel.id != slot.channel.id:
|
||||
# Summon them back!
|
||||
asyncio.create_task(member.edit(voice_channel=slot.channel))
|
||||
|
||||
slot_member = slot.members[member.id]
|
||||
data = slot_member.data
|
||||
|
||||
if before.channel and before.channel.id == slot.channel.id:
|
||||
# Left accountability room
|
||||
with data.batch_update():
|
||||
data.duration += int((utc_now() - data.last_joined_at).total_seconds())
|
||||
data.last_joined_at = None
|
||||
await slot.update_status()
|
||||
elif after.channel and after.channel.id == slot.channel.id:
|
||||
# Joined accountability room
|
||||
with data.batch_update():
|
||||
data.last_joined_at = utc_now()
|
||||
await slot.update_status()
|
||||
|
||||
|
||||
async def _accountability_loop():
|
||||
"""
|
||||
Runloop in charge of executing the room update tasks at the correct times.
|
||||
"""
|
||||
# Wait until ready
|
||||
while not client.is_ready():
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Calculate starting next_time
|
||||
# Assume the resume logic has taken care of all events/tasks before current_time
|
||||
now = utc_now()
|
||||
if now.minute < 55:
|
||||
next_time = now.replace(minute=55, second=0, microsecond=0)
|
||||
else:
|
||||
next_time = now.replace(minute=0, second=0, microsecond=0) + datetime.timedelta(hours=1)
|
||||
|
||||
# Executor loop
|
||||
while True:
|
||||
# TODO: (FUTURE) handle cases where we actually execute much late than expected
|
||||
await sleep_until(next_time)
|
||||
if next_time.minute == 55:
|
||||
next_time = next_time + datetime.timedelta(minutes=5)
|
||||
# Open next sessions
|
||||
try:
|
||||
async with room_lock:
|
||||
await open_next(next_time)
|
||||
except Exception:
|
||||
# Unknown exception. Catch it so the loop doesn't die.
|
||||
client.log(
|
||||
"Error while opening new scheduled sessions! "
|
||||
"Exception traceback follows.\n{}".format(
|
||||
traceback.format_exc()
|
||||
),
|
||||
context="ACCOUNTABILITY_LOOP",
|
||||
level=logging.ERROR
|
||||
)
|
||||
elif next_time.minute == 0:
|
||||
# Start new sessions
|
||||
try:
|
||||
async with room_lock:
|
||||
await turnover()
|
||||
except Exception:
|
||||
# Unknown exception. Catch it so the loop doesn't die.
|
||||
client.log(
|
||||
"Error while starting scheduled sessions! "
|
||||
"Exception traceback follows.\n{}".format(
|
||||
traceback.format_exc()
|
||||
),
|
||||
context="ACCOUNTABILITY_LOOP",
|
||||
level=logging.ERROR
|
||||
)
|
||||
next_time = next_time + datetime.timedelta(minutes=55)
|
||||
|
||||
|
||||
async def _accountability_system_resume():
|
||||
"""
|
||||
Logic for starting the accountability system from cold.
|
||||
Essentially, session and state resume logic.
|
||||
"""
|
||||
now = utc_now()
|
||||
|
||||
# Fetch the open room data, only takes into account currently running sessions.
|
||||
# May include sessions that were never opened, or opened but never started
|
||||
# Does not include sessions that were opened that start on the next hour
|
||||
open_room_data = accountability_rooms.fetch_rows_where(
|
||||
closed_at=NULL,
|
||||
start_at=LEQ(now),
|
||||
guildid=THIS_SHARD,
|
||||
_extra="ORDER BY start_at ASC"
|
||||
)
|
||||
|
||||
if open_room_data:
|
||||
# Extract member data of these rows
|
||||
member_data = accountability_members.fetch_rows_where(
|
||||
slotid=[row.slotid for row in open_room_data]
|
||||
)
|
||||
slot_members = collections.defaultdict(list)
|
||||
for row in member_data:
|
||||
slot_members[row.slotid].append(row)
|
||||
|
||||
# Filter these into expired rooms and current rooms
|
||||
expired_room_data = []
|
||||
current_room_data = []
|
||||
for row in open_room_data:
|
||||
if row.start_at + datetime.timedelta(hours=1) < now:
|
||||
expired_room_data.append(row)
|
||||
else:
|
||||
current_room_data.append(row)
|
||||
|
||||
session_updates = []
|
||||
|
||||
# TODO URGENT: Batch room updates here
|
||||
|
||||
# Expire the expired rooms
|
||||
for row in expired_room_data:
|
||||
if row.channelid is None or row.messageid is None:
|
||||
# TODO refunds here
|
||||
# If the rooms were never opened, close them and skip
|
||||
row.closed_at = now
|
||||
else:
|
||||
# If the rooms were opened and maybe started, make optimistic guesses on session data and close.
|
||||
session_end = row.start_at + datetime.timedelta(hours=1)
|
||||
session_updates.extend(
|
||||
(mow.duration + int((session_end - mow.last_joined_at).total_seconds()),
|
||||
None, mow.slotid, mow.userid)
|
||||
for mow in slot_members[row.slotid] if mow.last_joined_at
|
||||
)
|
||||
if client.get_guild(row.guildid):
|
||||
slot = TimeSlot(client.get_guild(row.guildid), row.start_at, data=row).load(
|
||||
memberids=[mow.userid for mow in slot_members[row.slotid]]
|
||||
)
|
||||
try:
|
||||
await slot.close()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
row.closed_at = now
|
||||
|
||||
# Load the in-progress room data
|
||||
if current_room_data:
|
||||
async with voice_ignore_lock:
|
||||
current_hour = now.replace(minute=0, second=0, microsecond=0)
|
||||
await open_next(current_hour)
|
||||
[aguild.advance() for aguild in AccountabilityGuild.cache.values()]
|
||||
|
||||
current_slots = [
|
||||
aguild.current_slot
|
||||
for aguild in AccountabilityGuild.cache.values()
|
||||
if aguild.current_slot
|
||||
]
|
||||
|
||||
session_updates.extend(
|
||||
(mem.data.duration + int((now - mem.data.last_joined_at).total_seconds()),
|
||||
None, mem.slotid, mem.userid)
|
||||
for slot in current_slots
|
||||
for mem in slot.members.values()
|
||||
if mem.data.last_joined_at and mem.member not in slot.channel.members
|
||||
)
|
||||
|
||||
session_updates.extend(
|
||||
(mem.data.duration,
|
||||
now, mem.slotid, mem.userid)
|
||||
for slot in current_slots
|
||||
for mem in slot.members.values()
|
||||
if not mem.data.last_joined_at and mem.member in slot.channel.members
|
||||
)
|
||||
|
||||
if session_updates:
|
||||
accountability_members.update_many(
|
||||
*session_updates,
|
||||
set_keys=('duration', 'last_joined_at'),
|
||||
where_keys=('slotid', 'userid'),
|
||||
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
*(aguild.current_slot.start()
|
||||
for aguild in AccountabilityGuild.cache.values() if aguild.current_slot)
|
||||
)
|
||||
else:
|
||||
if session_updates:
|
||||
accountability_members.update_many(
|
||||
*session_updates,
|
||||
set_keys=('duration', 'last_joined_at'),
|
||||
where_keys=('slotid', 'userid'),
|
||||
cast_row='(NULL::int, NULL::timestamptz, NULL::int, NULL::int)'
|
||||
)
|
||||
|
||||
# If we are in the last five minutes of the hour, open new rooms.
|
||||
# Note that these may already have been opened, or they may not have been.
|
||||
if now.minute >= 55:
|
||||
await open_next(
|
||||
now.replace(minute=0, second=0, microsecond=0) + datetime.timedelta(hours=1)
|
||||
)
|
||||
|
||||
|
||||
@module.launch_task
|
||||
async def launch_accountability_system(client):
|
||||
"""
|
||||
Launcher for the accountability system.
|
||||
Resumes saved sessions, and starts the accountability loop.
|
||||
"""
|
||||
# Load the AccountabilityGuild cache
|
||||
guilds = tables.guild_config.fetch_rows_where(
|
||||
accountability_category=NOTNULL,
|
||||
guildid=THIS_SHARD
|
||||
)
|
||||
# Further filter out any guilds that we aren't in
|
||||
[AccountabilityGuild(guild.guildid) for guild in guilds if client.get_guild(guild.guildid)]
|
||||
await _accountability_system_resume()
|
||||
asyncio.create_task(_accountability_loop())
|
||||
|
||||
|
||||
async def unload_accountability(client):
|
||||
"""
|
||||
Save the current sessions and cancel the runloop in preparation for client shutdown.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@client.add_after_event('member_join')
|
||||
async def restore_accountability(client, member):
|
||||
"""
|
||||
Restore accountability channel permissions when a member rejoins the server, if applicable.
|
||||
"""
|
||||
aguild = AccountabilityGuild.cache.get(member.guild.id, None)
|
||||
if aguild:
|
||||
if aguild.current_slot and member.id in aguild.current_slot.members:
|
||||
# Restore member permission for current slot
|
||||
slot = aguild.current_slot
|
||||
if slot.channel:
|
||||
asyncio.create_task(discord_shield(
|
||||
slot.channel.set_permissions(
|
||||
member,
|
||||
overwrite=slot._member_overwrite
|
||||
)
|
||||
))
|
||||
if aguild.upcoming_slot and member.id in aguild.upcoming_slot.members:
|
||||
slot = aguild.upcoming_slot
|
||||
if slot.channel:
|
||||
asyncio.create_task(discord_shield(
|
||||
slot.channel.set_permissions(
|
||||
member,
|
||||
overwrite=slot._member_overwrite
|
||||
)
|
||||
))
|
||||
0
src/modules/pending-rewrite/guide/__init__.py
Normal file
0
src/modules/pending-rewrite/guide/__init__.py
Normal file
7
src/modules/pending-rewrite/guild_admin/__init__.py
Normal file
7
src/modules/pending-rewrite/guild_admin/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .module import module
|
||||
|
||||
from . import guild_config
|
||||
from . import statreset
|
||||
from . import new_members
|
||||
from . import reaction_roles
|
||||
from . import economy
|
||||
@@ -0,0 +1,3 @@
|
||||
from ..module import module
|
||||
|
||||
from . import set_coins
|
||||
104
src/modules/pending-rewrite/guild_admin/economy/set_coins.py
Normal file
104
src/modules/pending-rewrite/guild_admin/economy/set_coins.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import discord
|
||||
import datetime
|
||||
from wards import guild_admin
|
||||
|
||||
from settings import GuildSettings
|
||||
from core import Lion
|
||||
|
||||
from ..module import module
|
||||
|
||||
POSTGRES_INT_MAX = 2147483647
|
||||
|
||||
@module.cmd(
|
||||
"set_coins",
|
||||
group="Guild Admin",
|
||||
desc="Set coins on a member."
|
||||
)
|
||||
@guild_admin()
|
||||
async def cmd_set(ctx):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}set_coins <user mention> <amount>
|
||||
Description:
|
||||
Sets the given number of coins on the mentioned user.
|
||||
If a number greater than 0 is mentioned, will add coins.
|
||||
If a number less than 0 is mentioned, will remove coins.
|
||||
Note: LionCoins on a member cannot be negative.
|
||||
Example:
|
||||
{prefix}set_coins {ctx.author.mention} 100
|
||||
{prefix}set_coins {ctx.author.mention} -100
|
||||
"""
|
||||
# Extract target and amount
|
||||
# Handle a slightly more flexible input than stated
|
||||
splits = ctx.args.split()
|
||||
digits = [isNumber(split) for split in splits[:2]]
|
||||
mentions = ctx.msg.mentions
|
||||
if len(splits) < 2 or not any(digits) or not (all(digits) or mentions):
|
||||
return await _send_usage(ctx)
|
||||
|
||||
if all(digits):
|
||||
# Both are digits, hopefully one is a member id, and one is an amount.
|
||||
target, amount = ctx.guild.get_member(int(splits[0])), int(splits[1])
|
||||
if not target:
|
||||
amount, target = int(splits[0]), ctx.guild.get_member(int(splits[1]))
|
||||
if not target:
|
||||
return await _send_usage(ctx)
|
||||
elif digits[0]:
|
||||
amount, target = int(splits[0]), mentions[0]
|
||||
elif digits[1]:
|
||||
target, amount = mentions[0], int(splits[1])
|
||||
|
||||
# Fetch the associated lion
|
||||
target_lion = Lion.fetch(ctx.guild.id, target.id)
|
||||
|
||||
# Check sanity conditions
|
||||
if target == ctx.client.user:
|
||||
return await ctx.embed_reply("Thanks, but Ari looks after all my needs!")
|
||||
if target.bot:
|
||||
return await ctx.embed_reply("We are still waiting for {} to open an account.".format(target.mention))
|
||||
|
||||
# Finally, send the amount and the ack message
|
||||
# Postgres `coins` column is `integer`, sanity check postgres int limits - which are smalled than python int range
|
||||
target_coins_to_set = target_lion.coins + amount
|
||||
if target_coins_to_set >= 0 and target_coins_to_set <= POSTGRES_INT_MAX:
|
||||
target_lion.addCoins(amount)
|
||||
elif target_coins_to_set < 0:
|
||||
target_coins_to_set = -target_lion.coins # Coins cannot go -ve, cap to 0
|
||||
target_lion.addCoins(target_coins_to_set)
|
||||
target_coins_to_set = 0
|
||||
else:
|
||||
return await ctx.embed_reply("Member coins cannot be more than {}".format(POSTGRES_INT_MAX))
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Funds Set",
|
||||
description="You have set LionCoins on {} to **{}**!".format(target.mention,target_coins_to_set),
|
||||
colour=discord.Colour.orange(),
|
||||
timestamp=datetime.datetime.utcnow()
|
||||
).set_footer(text=str(ctx.author), icon_url=ctx.author.avatar_url)
|
||||
|
||||
await ctx.reply(embed=embed, reference=ctx.msg)
|
||||
GuildSettings(ctx.guild.id).event_log.log(
|
||||
"{} set {}'s LionCoins to`{}`.".format(
|
||||
ctx.author.mention,
|
||||
target.mention,
|
||||
target_coins_to_set
|
||||
),
|
||||
title="Funds Set"
|
||||
)
|
||||
|
||||
def isNumber(var):
|
||||
try:
|
||||
return isinstance(int(var), int)
|
||||
except:
|
||||
return False
|
||||
|
||||
async def _send_usage(ctx):
|
||||
return await ctx.error_reply(
|
||||
"**Usage:** `{prefix}set_coins <mention> <amount>`\n"
|
||||
"**Example:**\n"
|
||||
" {prefix}set_coins {ctx.author.mention} 100\n"
|
||||
" {prefix}set_coins {ctx.author.mention} -100".format(
|
||||
prefix=ctx.best_prefix,
|
||||
ctx=ctx
|
||||
)
|
||||
)
|
||||
163
src/modules/pending-rewrite/guild_admin/guild_config.py
Normal file
163
src/modules/pending-rewrite/guild_admin/guild_config.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import difflib
|
||||
import discord
|
||||
from cmdClient.lib import SafeCancellation
|
||||
|
||||
from wards import guild_admin, guild_moderator
|
||||
from settings import UserInputError, GuildSettings
|
||||
|
||||
from utils.lib import prop_tabulate
|
||||
import utils.ctx_addons # noqa
|
||||
|
||||
from .module import module
|
||||
|
||||
|
||||
# Pages of configuration categories to display
|
||||
cat_pages = {
|
||||
'Administration': ('Meta', 'Guild Roles', 'New Members'),
|
||||
'Moderation': ('Moderation', 'Video Channels'),
|
||||
'Productivity': ('Study Tracking', 'TODO List', 'Workout'),
|
||||
'Study Rooms': ('Rented Rooms', 'Scheduled Sessions'),
|
||||
}
|
||||
|
||||
# Descriptions of each configuration category
|
||||
descriptions = {
|
||||
}
|
||||
|
||||
|
||||
@module.cmd("config",
|
||||
desc="View and modify the server settings.",
|
||||
flags=('add', 'remove'),
|
||||
group="Guild Configuration")
|
||||
@guild_moderator()
|
||||
async def cmd_config(ctx, flags):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}config
|
||||
{prefix}config info
|
||||
{prefix}config <setting>
|
||||
{prefix}config <setting> <value>
|
||||
Description:
|
||||
Display the server configuration panel, and view/modify the server settings.
|
||||
|
||||
Use `{prefix}config` to see the settings with their current values, or `{prefix}config info` to \
|
||||
show brief descriptions instead.
|
||||
Use `{prefix}config <setting>` (e.g. `{prefix}config event_log`) to view a more detailed description for each setting, \
|
||||
including the possible values.
|
||||
Finally, use `{prefix}config <setting> <value>` to set the setting to the given value.
|
||||
To unset a setting, or set it to the default, use `{prefix}config <setting> None`.
|
||||
|
||||
Additional usage for settings which accept a list of values:
|
||||
`{prefix}config <setting> <value1>, <value2>, ...`
|
||||
`{prefix}config <setting> --add <value1>, <value2>, ...`
|
||||
`{prefix}config <setting> --remove <value1>, <value2>, ...`
|
||||
Note that the first form *overwrites* the setting completely,\
|
||||
while the second two will only *add* and *remove* values, respectively.
|
||||
Examples``:
|
||||
{prefix}config event_log
|
||||
{prefix}config event_log {ctx.ch.name}
|
||||
{prefix}config autoroles Member, Level 0, Level 10
|
||||
{prefix}config autoroles --remove Level 10
|
||||
"""
|
||||
# Cache and map some info for faster access
|
||||
setting_displaynames = {setting.display_name.lower(): setting for setting in GuildSettings.settings.values()}
|
||||
|
||||
if not ctx.args or ctx.args.lower() in ('info', 'help'):
|
||||
# Fill the setting cats
|
||||
cats = {}
|
||||
for setting in GuildSettings.settings.values():
|
||||
cat = cats.get(setting.category, [])
|
||||
cat.append(setting)
|
||||
cats[setting.category] = cat
|
||||
|
||||
# Format the cats
|
||||
sections = {}
|
||||
for catname, cat in cats.items():
|
||||
catprops = {
|
||||
setting.display_name: setting.get(ctx.guild.id).summary if not ctx.args else setting.desc
|
||||
for setting in cat
|
||||
}
|
||||
# TODO: Add cat description here
|
||||
sections[catname] = prop_tabulate(*zip(*catprops.items()))
|
||||
|
||||
# Put the cats on the correct pages
|
||||
pages = []
|
||||
for page_name, cat_names in cat_pages.items():
|
||||
page = {
|
||||
cat_name: sections[cat_name] for cat_name in cat_names if cat_name in sections
|
||||
}
|
||||
if page:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
title=page_name,
|
||||
description=(
|
||||
"View brief setting descriptions with `{prefix}config info`.\n"
|
||||
"Use e.g. `{prefix}config event_log` to see more details.\n"
|
||||
"Modify a setting with e.g. `{prefix}config event_log {ctx.ch.name}`.\n"
|
||||
"See the [Online Tutorial]({tutorial}) for a complete setup guide.".format(
|
||||
prefix=ctx.best_prefix,
|
||||
ctx=ctx,
|
||||
tutorial="https://discord.studylions.com/tutorial"
|
||||
)
|
||||
)
|
||||
)
|
||||
for name, value in page.items():
|
||||
embed.add_field(name=name, value=value, inline=False)
|
||||
|
||||
pages.append(embed)
|
||||
|
||||
if len(pages) > 1:
|
||||
[
|
||||
embed.set_footer(text="Page {} of {}".format(i+1, len(pages)))
|
||||
for i, embed in enumerate(pages)
|
||||
]
|
||||
await ctx.pager(pages)
|
||||
elif pages:
|
||||
await ctx.reply(embed=pages[0])
|
||||
else:
|
||||
await ctx.reply("No configuration options set up yet!")
|
||||
else:
|
||||
# Some args were given
|
||||
parts = ctx.args.split(maxsplit=1)
|
||||
|
||||
name = parts[0]
|
||||
setting = setting_displaynames.get(name.lower(), None)
|
||||
if setting is None:
|
||||
matches = difflib.get_close_matches(name, setting_displaynames.keys(), n=2)
|
||||
match = "`{}`".format('` or `'.join(matches)) if matches else None
|
||||
return await ctx.error_reply(
|
||||
"Couldn't find a setting called `{}`!\n"
|
||||
"{}"
|
||||
"Use `{}config info` to see all the server settings.".format(
|
||||
name,
|
||||
"Maybe you meant {}?\n".format(match) if match else "",
|
||||
ctx.best_prefix
|
||||
)
|
||||
)
|
||||
|
||||
if len(parts) == 1 and not ctx.msg.attachments:
|
||||
# config <setting>
|
||||
# View config embed for provided setting
|
||||
await setting.get(ctx.guild.id).widget(ctx, flags=flags)
|
||||
else:
|
||||
# config <setting> <value>
|
||||
# Ignoring the write ward currently and just enforcing admin
|
||||
# Check the write ward
|
||||
# if not await setting.write_ward.run(ctx):
|
||||
# raise SafeCancellation(setting.write_ward.msg)
|
||||
if not await guild_admin.run(ctx):
|
||||
raise SafeCancellation("You need to be a server admin to modify settings!")
|
||||
|
||||
# Attempt to set config setting
|
||||
try:
|
||||
parsed = await setting.parse(ctx.guild.id, ctx, parts[1] if len(parts) > 1 else '')
|
||||
parsed.write(add_only=flags['add'], remove_only=flags['remove'])
|
||||
except UserInputError as e:
|
||||
await ctx.reply(embed=discord.Embed(
|
||||
description="{} {}".format('❌', e.msg),
|
||||
colour=discord.Colour.red()
|
||||
))
|
||||
else:
|
||||
await ctx.reply(embed=discord.Embed(
|
||||
description="{} {}".format('✅', setting.get(ctx.guild.id).success_response),
|
||||
colour=discord.Colour.green()
|
||||
))
|
||||
4
src/modules/pending-rewrite/guild_admin/module.py
Normal file
4
src/modules/pending-rewrite/guild_admin/module.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from LionModule import LionModule
|
||||
|
||||
|
||||
module = LionModule("Guild_Admin")
|
||||
@@ -0,0 +1,3 @@
|
||||
from . import settings
|
||||
from . import greetings
|
||||
from . import roles
|
||||
@@ -0,0 +1,6 @@
|
||||
from data import Table, RowTable
|
||||
|
||||
|
||||
autoroles = Table('autoroles')
|
||||
bot_autoroles = Table('bot_autoroles')
|
||||
past_member_roles = Table('past_member_roles')
|
||||
@@ -0,0 +1,29 @@
|
||||
import discord
|
||||
from LionContext import LionContext as Context
|
||||
|
||||
from meta import client
|
||||
|
||||
from .settings import greeting_message, greeting_channel, returning_message
|
||||
|
||||
|
||||
@client.add_after_event('member_join')
|
||||
async def send_greetings(client, member):
|
||||
guild = member.guild
|
||||
|
||||
returning = bool(client.data.lions.fetch((guild.id, member.id)))
|
||||
|
||||
# Handle greeting message
|
||||
channel = greeting_channel.get(guild.id).value
|
||||
if channel is not None:
|
||||
if channel == greeting_channel.DMCHANNEL:
|
||||
channel = member
|
||||
|
||||
ctx = Context(client, guild=guild, author=member)
|
||||
if returning:
|
||||
args = returning_message.get(guild.id).args(ctx)
|
||||
else:
|
||||
args = greeting_message.get(guild.id).args(ctx)
|
||||
try:
|
||||
await channel.send(**args)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
115
src/modules/pending-rewrite/guild_admin/new_members/roles.py
Normal file
115
src/modules/pending-rewrite/guild_admin/new_members/roles.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import asyncio
|
||||
import discord
|
||||
from collections import defaultdict
|
||||
|
||||
from meta import client
|
||||
from core import Lion
|
||||
from settings import GuildSettings
|
||||
|
||||
from .settings import autoroles, bot_autoroles, role_persistence
|
||||
from .data import past_member_roles
|
||||
|
||||
|
||||
# Locks to avoid storing the roles while adding them
|
||||
# The locking is cautious, leaving data unchanged upon collision
|
||||
locks = defaultdict(asyncio.Lock)
|
||||
|
||||
|
||||
@client.add_after_event('member_join')
|
||||
async def join_role_tracker(client, member):
|
||||
"""
|
||||
Add autoroles or saved roles as needed.
|
||||
"""
|
||||
guild = member.guild
|
||||
if not guild.me.guild_permissions.manage_roles:
|
||||
# We can't manage the roles here, don't try to give/restore the member roles
|
||||
return
|
||||
|
||||
async with locks[(guild.id, member.id)]:
|
||||
if role_persistence.get(guild.id).value and client.data.lions.fetch((guild.id, member.id)):
|
||||
# Lookup stored roles
|
||||
role_rows = past_member_roles.select_where(
|
||||
guildid=guild.id,
|
||||
userid=member.id
|
||||
)
|
||||
# Identify roles from roleids
|
||||
roles = (guild.get_role(row['roleid']) for row in role_rows)
|
||||
# Remove non-existent roles
|
||||
roles = (role for role in roles if role is not None)
|
||||
# Remove roles the client can't add
|
||||
roles = [role for role in roles if role < guild.me.top_role]
|
||||
if roles:
|
||||
try:
|
||||
await member.add_roles(
|
||||
*roles,
|
||||
reason="Restoring saved roles.",
|
||||
)
|
||||
except discord.HTTPException:
|
||||
# This shouldn't ususally happen, but there are valid cases where it can
|
||||
# E.g. the user left while we were restoring their roles
|
||||
pass
|
||||
# Event log!
|
||||
GuildSettings(guild.id).event_log.log(
|
||||
"Restored the following roles for returning member {}:\n{}".format(
|
||||
member.mention,
|
||||
', '.join(role.mention for role in roles)
|
||||
),
|
||||
title="Saved roles restored"
|
||||
)
|
||||
else:
|
||||
# Add autoroles
|
||||
roles = bot_autoroles.get(guild.id).value if member.bot else autoroles.get(guild.id).value
|
||||
# Remove roles the client can't add
|
||||
roles = [role for role in roles if role < guild.me.top_role]
|
||||
if roles:
|
||||
try:
|
||||
await member.add_roles(
|
||||
*roles,
|
||||
reason="Adding autoroles.",
|
||||
)
|
||||
except discord.HTTPException:
|
||||
# This shouldn't ususally happen, but there are valid cases where it can
|
||||
# E.g. the user left while we were adding autoroles
|
||||
pass
|
||||
# Event log!
|
||||
GuildSettings(guild.id).event_log.log(
|
||||
"Gave {} the guild autoroles:\n{}".format(
|
||||
member.mention,
|
||||
', '.join(role.mention for role in roles)
|
||||
),
|
||||
titles="Autoroles added"
|
||||
)
|
||||
|
||||
|
||||
@client.add_after_event('member_remove')
|
||||
async def left_role_tracker(client, member):
|
||||
"""
|
||||
Delete and re-store member roles when they leave the server.
|
||||
"""
|
||||
if (member.guild.id, member.id) in locks and locks[(member.guild.id, member.id)].locked():
|
||||
# Currently processing a join event
|
||||
# Which means the member left while we were adding their roles
|
||||
# Cautiously return, not modifying the saved role data
|
||||
return
|
||||
|
||||
# Delete existing member roles for this user
|
||||
# NOTE: Not concurrency-safe
|
||||
past_member_roles.delete_where(
|
||||
guildid=member.guild.id,
|
||||
userid=member.id,
|
||||
)
|
||||
if role_persistence.get(member.guild.id).value:
|
||||
# Make sure the user has an associated lion, so we can detect when they rejoin
|
||||
Lion.fetch(member.guild.id, member.id)
|
||||
|
||||
# Then insert the current member roles
|
||||
values = [
|
||||
(member.guild.id, member.id, role.id)
|
||||
for role in member.roles
|
||||
if not role.is_bot_managed() and not role.is_integration() and not role.is_default()
|
||||
]
|
||||
if values:
|
||||
past_member_roles.insert_many(
|
||||
*values,
|
||||
insert_keys=('guildid', 'userid', 'roleid')
|
||||
)
|
||||
303
src/modules/pending-rewrite/guild_admin/new_members/settings.py
Normal file
303
src/modules/pending-rewrite/guild_admin/new_members/settings.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import datetime
|
||||
import discord
|
||||
|
||||
import settings
|
||||
from settings import GuildSettings, GuildSetting
|
||||
import settings.setting_types as stypes
|
||||
from wards import guild_admin
|
||||
|
||||
from .data import autoroles, bot_autoroles
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class greeting_channel(stypes.Channel, GuildSetting):
|
||||
"""
|
||||
Setting describing the destination of the greeting message.
|
||||
|
||||
Extended to support the following special values, with input and output supported.
|
||||
Data `None` corresponds to `Off`.
|
||||
Data `1` corresponds to `DM`.
|
||||
"""
|
||||
DMCHANNEL = object()
|
||||
|
||||
category = "New Members"
|
||||
|
||||
attr_name = 'greeting_channel'
|
||||
_data_column = 'greeting_channel'
|
||||
|
||||
display_name = "welcome_channel"
|
||||
desc = "Channel to send the welcome message in"
|
||||
|
||||
long_desc = (
|
||||
"Channel to post the `welcome_message` in when a new user joins the server. "
|
||||
"Accepts `DM` to indicate the welcome should be sent via direct message."
|
||||
)
|
||||
_accepts = (
|
||||
"Text Channel name/id/mention, or `DM`, or `None` to disable."
|
||||
)
|
||||
_chan_type = discord.ChannelType.text
|
||||
|
||||
@classmethod
|
||||
def _data_to_value(cls, id, data, **kwargs):
|
||||
if data is None:
|
||||
return None
|
||||
elif data == 1:
|
||||
return cls.DMCHANNEL
|
||||
else:
|
||||
return super()._data_to_value(id, data, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _data_from_value(cls, id, value, **kwargs):
|
||||
if value is None:
|
||||
return None
|
||||
elif value == cls.DMCHANNEL:
|
||||
return 1
|
||||
else:
|
||||
return super()._data_from_value(id, value, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def _parse_userstr(cls, ctx, id, userstr, **kwargs):
|
||||
lower = userstr.lower()
|
||||
if lower in ('0', 'none', 'off'):
|
||||
return None
|
||||
elif lower == 'dm':
|
||||
return 1
|
||||
else:
|
||||
return await super()._parse_userstr(ctx, id, userstr, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, id, data, **kwargs):
|
||||
if data is None:
|
||||
return "Off"
|
||||
elif data == 1:
|
||||
return "DM"
|
||||
else:
|
||||
return "<#{}>".format(data)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
value = self.value
|
||||
if not value:
|
||||
return "Welcome messages are disabled."
|
||||
elif value == self.DMCHANNEL:
|
||||
return "Welcome messages will be sent via direct message."
|
||||
else:
|
||||
return "Welcome messages will be posted in {}".format(self.formatted)
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class greeting_message(stypes.Message, GuildSetting):
|
||||
category = "New Members"
|
||||
|
||||
attr_name = 'greeting_message'
|
||||
_data_column = 'greeting_message'
|
||||
|
||||
display_name = 'welcome_message'
|
||||
desc = "Welcome message sent to welcome new members."
|
||||
|
||||
long_desc = (
|
||||
"Message to send to the configured `welcome_channel` when a member joins the server for the first time."
|
||||
)
|
||||
|
||||
_default = r"""
|
||||
{
|
||||
"embed": {
|
||||
"title": "Welcome!",
|
||||
"thumbnail": {"url": "{guild_icon}"},
|
||||
"description": "Hi {mention}!\nWelcome to **{guild_name}**! You are the **{member_count}**th member.\nThere are currently **{studying_count}** people studying.\nGood luck and stay productive!",
|
||||
"color": 15695665
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
_substitution_desc = {
|
||||
'{mention}': "Mention the new member.",
|
||||
'{user_name}': "Username of the new member.",
|
||||
'{user_avatar}': "Avatar of the new member.",
|
||||
'{guild_name}': "Name of this server.",
|
||||
'{guild_icon}': "Server icon url.",
|
||||
'{member_count}': "Number of members in the server.",
|
||||
'{studying_count}': "Number of current voice channel members.",
|
||||
}
|
||||
|
||||
def substitution_keys(self, ctx, **kwargs):
|
||||
return {
|
||||
'{mention}': ctx.author.mention,
|
||||
'{user_name}': ctx.author.name,
|
||||
'{user_avatar}': str(ctx.author.avatar_url),
|
||||
'{guild_name}': ctx.guild.name,
|
||||
'{guild_icon}': str(ctx.guild.icon_url),
|
||||
'{member_count}': str(len(ctx.guild.members)),
|
||||
'{studying_count}': str(len([member for ch in ctx.guild.voice_channels for member in ch.members]))
|
||||
}
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
return "The welcome message has been set!"
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class returning_message(stypes.Message, GuildSetting):
|
||||
category = "New Members"
|
||||
|
||||
attr_name = 'returning_message'
|
||||
_data_column = 'returning_message'
|
||||
|
||||
display_name = 'returning_message'
|
||||
desc = "Welcome message sent to returning members."
|
||||
|
||||
long_desc = (
|
||||
"Message to send to the configured `welcome_channel` when a member returns to the server."
|
||||
)
|
||||
|
||||
_default = r"""
|
||||
{
|
||||
"embed": {
|
||||
"title": "Welcome Back {user_name}!",
|
||||
"thumbnail": {"url": "{guild_icon}"},
|
||||
"description": "Welcome back to **{guild_name}**!\nYou last studied with us <t:{last_time}:R>.\nThere are currently **{studying_count}** people studying.\nGood luck and stay productive!",
|
||||
"color": 15695665
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
_substitution_desc = {
|
||||
'{mention}': "Mention the returning member.",
|
||||
'{user_name}': "Username of the member.",
|
||||
'{user_avatar}': "Avatar of the member.",
|
||||
'{guild_name}': "Name of this server.",
|
||||
'{guild_icon}': "Server icon url.",
|
||||
'{member_count}': "Number of members in the server.",
|
||||
'{studying_count}': "Number of current voice channel members.",
|
||||
'{last_time}': "Unix timestamp of the last time the member studied.",
|
||||
}
|
||||
|
||||
def substitution_keys(self, ctx, **kwargs):
|
||||
return {
|
||||
'{mention}': ctx.author.mention,
|
||||
'{user_name}': ctx.author.name,
|
||||
'{user_avatar}': str(ctx.author.avatar_url),
|
||||
'{guild_name}': ctx.guild.name,
|
||||
'{guild_icon}': str(ctx.guild.icon_url),
|
||||
'{member_count}': str(len(ctx.guild.members)),
|
||||
'{studying_count}': str(len([member for ch in ctx.guild.voice_channels for member in ch.members])),
|
||||
'{last_time}': int(ctx.alion.data._timestamp.replace(tzinfo=datetime.timezone.utc).timestamp()),
|
||||
}
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
return "The returning message has been set!"
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class starting_funds(stypes.Integer, GuildSetting):
|
||||
category = "New Members"
|
||||
|
||||
attr_name = 'starting_funds'
|
||||
_data_column = 'starting_funds'
|
||||
|
||||
display_name = 'starting_funds'
|
||||
desc = "Coins given when a user first joins."
|
||||
|
||||
long_desc = (
|
||||
"Members will be given this number of coins the first time they join the server."
|
||||
)
|
||||
|
||||
_default = 1000
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
return "Members will be given `{}` coins when they first join the server.".format(self.formatted)
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class autoroles(stypes.RoleList, settings.ListData, settings.Setting):
|
||||
category = "New Members"
|
||||
write_ward = guild_admin
|
||||
|
||||
attr_name = 'autoroles'
|
||||
|
||||
_table_interface = autoroles
|
||||
_id_column = 'guildid'
|
||||
_data_column = 'roleid'
|
||||
|
||||
display_name = "autoroles"
|
||||
desc = "Roles to give automatically to new members."
|
||||
|
||||
_force_unique = True
|
||||
|
||||
long_desc = (
|
||||
"These roles will be given automatically to users when they join the server. "
|
||||
"If `role_persistence` is enabled, the roles will only be given the first time a user joins the server."
|
||||
)
|
||||
|
||||
# Flat cache, no need to expire
|
||||
_cache = {}
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "New members will be given the following roles:\n{}".format(self.formatted)
|
||||
else:
|
||||
return "New members will not automatically be given any roles."
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class bot_autoroles(stypes.RoleList, settings.ListData, settings.Setting):
|
||||
category = "New Members"
|
||||
write_ward = guild_admin
|
||||
|
||||
attr_name = 'bot_autoroles'
|
||||
|
||||
_table_interface = bot_autoroles
|
||||
_id_column = 'guildid'
|
||||
_data_column = 'roleid'
|
||||
|
||||
display_name = "bot_autoroles"
|
||||
desc = "Roles to give automatically to new bots."
|
||||
|
||||
_force_unique = True
|
||||
|
||||
long_desc = (
|
||||
"These roles will be given automatically to bots when they join the server. "
|
||||
"If `role_persistence` is enabled, the roles will only be given the first time a bot joins the server."
|
||||
)
|
||||
|
||||
# Flat cache, no need to expire
|
||||
_cache = {}
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "New bots will be given the following roles:\n{}".format(self.formatted)
|
||||
else:
|
||||
return "New bots will not automatically be given any roles."
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class role_persistence(stypes.Boolean, GuildSetting):
|
||||
category = "New Members"
|
||||
|
||||
attr_name = "role_persistence"
|
||||
|
||||
_data_column = 'persist_roles'
|
||||
|
||||
display_name = "role_persistence"
|
||||
desc = "Whether to remember member roles when they leave the server."
|
||||
_outputs = {True: "Enabled", False: "Disabled"}
|
||||
_default = True
|
||||
|
||||
long_desc = (
|
||||
"When enabled, restores member roles when they rejoin the server.\n"
|
||||
"This enables profile roles and purchased roles, such as field of study and colour roles, "
|
||||
"as well as moderation roles, "
|
||||
"such as the studyban and mute roles, to persist even when a member leaves and rejoins.\n"
|
||||
"Note: Members who leave while this is disabled will not have their roles restored."
|
||||
)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "Roles will now be restored when a member rejoins."
|
||||
else:
|
||||
return "Member roles will no longer be saved or restored."
|
||||
@@ -0,0 +1,6 @@
|
||||
from .module import module
|
||||
|
||||
from . import data
|
||||
from . import settings
|
||||
from . import tracker
|
||||
from . import command
|
||||
@@ -0,0 +1,943 @@
|
||||
import asyncio
|
||||
import discord
|
||||
from discord import PartialEmoji
|
||||
|
||||
from cmdClient.lib import ResponseTimedOut, UserCancelled
|
||||
from wards import guild_admin
|
||||
from settings import UserInputError
|
||||
from utils.lib import tick, cross
|
||||
|
||||
from .module import module
|
||||
from .tracker import ReactionRoleMessage
|
||||
from .data import reaction_role_reactions, reaction_role_messages
|
||||
from . import settings
|
||||
|
||||
|
||||
example_emoji = "🧮"
|
||||
example_str = "🧮 mathematics, 🫀 biology, 💻 computer science, 🖼️ design, 🩺 medicine"
|
||||
|
||||
|
||||
def _parse_messageref(ctx):
|
||||
"""
|
||||
Parse a message reference from the context message and return it.
|
||||
Removes the parsed string from `ctx.args` if applicable.
|
||||
Supports the following reference types, in precedence order:
|
||||
- A Discord message reply reference.
|
||||
- A message link.
|
||||
- A message id.
|
||||
|
||||
Returns: (channelid, messageid)
|
||||
`messageid` will be `None` if a valid reference was not found.
|
||||
`channelid` will be `None` if the message was provided by pure id.
|
||||
"""
|
||||
target_id = None
|
||||
target_chid = None
|
||||
|
||||
if ctx.msg.reference:
|
||||
# True message reference extract message and return
|
||||
target_id = ctx.msg.reference.message_id
|
||||
target_chid = ctx.msg.reference.channel_id
|
||||
elif ctx.args:
|
||||
# Parse the first word of the message arguments
|
||||
splits = ctx.args.split(maxsplit=1)
|
||||
maybe_target = splits[0]
|
||||
|
||||
# Expect a message id or message link
|
||||
if maybe_target.isdigit():
|
||||
# Assume it is a message id
|
||||
target_id = int(maybe_target)
|
||||
elif '/' in maybe_target:
|
||||
# Assume it is a link
|
||||
# Split out the channelid and messageid, if possible
|
||||
link_splits = maybe_target.rsplit('/', maxsplit=2)
|
||||
if len(link_splits) > 1 and link_splits[-1].isdigit() and link_splits[-2].isdigit():
|
||||
target_id = int(link_splits[-1])
|
||||
target_chid = int(link_splits[-2])
|
||||
|
||||
# If we found a target id, truncate the arguments
|
||||
if target_id is not None:
|
||||
if len(splits) > 1:
|
||||
ctx.args = splits[1].strip()
|
||||
else:
|
||||
ctx.args = ""
|
||||
else:
|
||||
# Last-ditch attempt, see if the argument could be a stored reaction
|
||||
maybe_emoji = maybe_target.strip(',')
|
||||
guild_message_rows = reaction_role_messages.fetch_rows_where(guildid=ctx.guild.id)
|
||||
messages = [ReactionRoleMessage.fetch(row.messageid) for row in guild_message_rows]
|
||||
emojis = {reaction.emoji: message for message in messages for reaction in message.reactions}
|
||||
emoji_name_map = {emoji.name.lower(): emoji for emoji in emojis}
|
||||
emoji_id_map = {emoji.id: emoji for emoji in emojis if emoji.id}
|
||||
result = _parse_emoji(maybe_emoji, emoji_name_map, emoji_id_map)
|
||||
if result and result in emojis:
|
||||
message = emojis[result]
|
||||
target_id = message.messageid
|
||||
target_chid = message.data.channelid
|
||||
|
||||
# Return the message reference
|
||||
return (target_chid, target_id)
|
||||
|
||||
|
||||
def _parse_emoji(emoji_str, name_map, id_map):
|
||||
"""
|
||||
Extract a PartialEmoji from a user provided emoji string, given the accepted raw names and ids.
|
||||
"""
|
||||
emoji = None
|
||||
if len(emoji_str) < 10 and all(ord(char) >= 256 for char in emoji_str):
|
||||
# The string is pure unicode, we assume built in emoji
|
||||
emoji = PartialEmoji(name=emoji_str)
|
||||
elif emoji_str.lower() in name_map:
|
||||
emoji = name_map[emoji_str.lower()]
|
||||
elif emoji_str.isdigit() and int(emoji_str) in id_map:
|
||||
emoji = id_map[int(emoji_str)]
|
||||
else:
|
||||
# Attempt to parse as custom emoji
|
||||
# Accept custom emoji provided in the full form
|
||||
emoji_split = emoji_str.strip('<>:').split(':')
|
||||
if len(emoji_split) in (2, 3) and emoji_split[-1].isdigit():
|
||||
emoji_id = int(emoji_split[-1])
|
||||
emoji_name = emoji_split[-2]
|
||||
emoji_animated = emoji_split[0] == 'a'
|
||||
emoji = PartialEmoji(
|
||||
name=emoji_name,
|
||||
id=emoji_id,
|
||||
animated=emoji_animated
|
||||
)
|
||||
return emoji
|
||||
|
||||
|
||||
async def reaction_ask(ctx, question, timeout=120, timeout_msg=None, cancel_msg=None):
|
||||
"""
|
||||
Asks the author the provided question in an embed, and provides check/cross reactions for answering.
|
||||
"""
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
description=question
|
||||
)
|
||||
out_msg = await ctx.reply(embed=embed)
|
||||
|
||||
# Wait for a tick/cross
|
||||
asyncio.create_task(out_msg.add_reaction(tick))
|
||||
asyncio.create_task(out_msg.add_reaction(cross))
|
||||
|
||||
def check(reaction, user):
|
||||
result = True
|
||||
result = result and reaction.message == out_msg
|
||||
result = result and user == ctx.author
|
||||
result = result and (reaction.emoji == tick or reaction.emoji == cross)
|
||||
return result
|
||||
|
||||
try:
|
||||
reaction, _ = await ctx.client.wait_for(
|
||||
'reaction_add',
|
||||
check=check,
|
||||
timeout=120
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
await out_msg.edit(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.red(),
|
||||
description=timeout_msg or "Prompt timed out."
|
||||
)
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
raise ResponseTimedOut from None
|
||||
if reaction.emoji == cross:
|
||||
try:
|
||||
await out_msg.edit(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.red(),
|
||||
description=cancel_msg or "Cancelled."
|
||||
)
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
raise UserCancelled from None
|
||||
|
||||
try:
|
||||
await out_msg.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
|
||||
_message_setting_flags = {
|
||||
'removable': settings.removable,
|
||||
'maximum': settings.maximum,
|
||||
'required_role': settings.required_role,
|
||||
'log': settings.log,
|
||||
'refunds': settings.refunds,
|
||||
'default_price': settings.default_price,
|
||||
}
|
||||
_reaction_setting_flags = {
|
||||
'price': settings.price,
|
||||
'duration': settings.duration
|
||||
}
|
||||
|
||||
|
||||
@module.cmd(
|
||||
"reactionroles",
|
||||
group="Guild Configuration",
|
||||
desc="Create and configure reaction role messages.",
|
||||
aliases=('rroles',),
|
||||
flags=(
|
||||
'delete', 'remove==',
|
||||
'enable', 'disable',
|
||||
'required_role==', 'removable=', 'maximum=', 'refunds=', 'log=', 'default_price=',
|
||||
'price=', 'duration=='
|
||||
)
|
||||
)
|
||||
@guild_admin()
|
||||
async def cmd_reactionroles(ctx, flags):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}rroles
|
||||
{prefix}rroles [enable|disable|delete] msglink
|
||||
{prefix}rroles msglink [emoji1 role1, emoji2 role2, ...]
|
||||
{prefix}rroles msglink --remove emoji1, emoji2, ...
|
||||
{prefix}rroles msglink --message_setting [value]
|
||||
{prefix}rroles msglink emoji --reaction_setting [value]
|
||||
Description:
|
||||
Create and configure "reaction roles", i.e. roles obtainable by \
|
||||
clicking reactions on a particular message.
|
||||
`msglink` is the link or message id of the message with reactions.
|
||||
`emoji` should be given as the emoji itself, or the name or id.
|
||||
`role` may be given by name, mention, or id.
|
||||
Getting started:
|
||||
First choose the message you want to add reaction roles to, \
|
||||
and copy the link or message id for that message. \
|
||||
Then run the command `{prefix}rroles link`, replacing `link` with the copied link, \
|
||||
and follow the prompts.
|
||||
For faster setup, use `{prefix}rroles link emoji1 role1, emoji2 role2` instead.
|
||||
Editing reaction roles:
|
||||
Remove roles with `{prefix}rroles link --remove emoji1, emoji2, ...`
|
||||
Add/edit roles with `{prefix}rroles link emoji1 role1, emoji2 role2, ...`
|
||||
Examples``:
|
||||
{prefix}rroles {ctx.msg.id} 🧮 mathematics, 🫀 biology, 🩺 medicine
|
||||
{prefix}rroles disable {ctx.msg.id}
|
||||
PAGEBREAK:
|
||||
Page 2
|
||||
Advanced configuration:
|
||||
Type `{prefix}rroles link` again to view the advanced setting window, \
|
||||
and use `{prefix}rroles link --setting value` to modify the settings. \
|
||||
See below for descriptions of each message setting.
|
||||
For example to disable event logging, run `{prefix}rroles link --log off`.
|
||||
|
||||
For per-reaction settings, instead use `{prefix}rroles link emoji --setting value`.
|
||||
|
||||
*(!) Replace `setting` with one of the settings below!*
|
||||
Message Settings::
|
||||
maximum: Maximum number of roles obtainable from this message.
|
||||
log: Whether to log reaction role usage into the event log.
|
||||
removable: Whether the reactions roles can be remove by unreacting.
|
||||
refunds: Whether to refund the role price when removing the role.
|
||||
default_price: The default price of each role on this message.
|
||||
required_role: The role required to use these reactions roles.
|
||||
Reaction Settings::
|
||||
price: The price of this reaction role. (May be negative for a reward.)
|
||||
tduration: How long this role will last after being selected or bought.
|
||||
Configuration Examples``:
|
||||
{prefix}rroles {ctx.msg.id} --maximum 5
|
||||
{prefix}rroles {ctx.msg.id} --default_price 20
|
||||
{prefix}rroles {ctx.msg.id} --required_role None
|
||||
{prefix}rroles {ctx.msg.id} 🧮 --price 1024
|
||||
{prefix}rroles {ctx.msg.id} 🧮 --duration 7 days
|
||||
"""
|
||||
if not ctx.args:
|
||||
# No target message provided, list the current reaction messages
|
||||
# Or give a brief guide if there are no current reaction messages
|
||||
guild_message_rows = reaction_role_messages.fetch_rows_where(guildid=ctx.guild.id)
|
||||
if guild_message_rows:
|
||||
# List messages
|
||||
|
||||
# First get the list of reaction role messages in the guild
|
||||
messages = [ReactionRoleMessage.fetch(row.messageid) for row in guild_message_rows]
|
||||
|
||||
# Sort them by channelid and messageid
|
||||
messages.sort(key=lambda m: (m.data.channelid, m.messageid))
|
||||
|
||||
# Build the message description strings
|
||||
message_strings = []
|
||||
for message in messages:
|
||||
header = (
|
||||
"`{}` in <#{}> ([Click to jump]({})){}".format(
|
||||
message.messageid,
|
||||
message.data.channelid,
|
||||
message.message_link,
|
||||
" (disabled)" if not message.enabled else ""
|
||||
)
|
||||
)
|
||||
role_strings = [
|
||||
"{} <@&{}>".format(reaction.emoji, reaction.data.roleid)
|
||||
for reaction in message.reactions
|
||||
]
|
||||
role_string = '\n'.join(role_strings) or "No reaction roles!"
|
||||
|
||||
message_strings.append("{}\n{}".format(header, role_string))
|
||||
|
||||
pages = []
|
||||
page = []
|
||||
page_len = 0
|
||||
page_chars = 0
|
||||
i = 0
|
||||
while i < len(message_strings):
|
||||
message_string = message_strings[i]
|
||||
chars = len(message_string)
|
||||
lines = len(message_string.splitlines())
|
||||
if (page and lines + page_len > 20) or (chars + page_chars > 2000):
|
||||
pages.append('\n\n'.join(page))
|
||||
page = []
|
||||
page_len = 0
|
||||
page_chars = 0
|
||||
else:
|
||||
page.append(message_string)
|
||||
page_len += lines
|
||||
page_chars += chars
|
||||
i += 1
|
||||
if page:
|
||||
pages.append('\n\n'.join(page))
|
||||
|
||||
page_count = len(pages)
|
||||
title = "Reaction Roles in {}".format(ctx.guild.name)
|
||||
embeds = [
|
||||
discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
description=page,
|
||||
title=title
|
||||
)
|
||||
for page in pages
|
||||
]
|
||||
if page_count > 1:
|
||||
[embed.set_footer(text="Page {} of {}".format(i + 1, page_count)) for i, embed in enumerate(embeds)]
|
||||
await ctx.pager(embeds)
|
||||
else:
|
||||
# Send a setup guide
|
||||
embed = discord.Embed(
|
||||
title="No Reaction Roles set up!",
|
||||
description=(
|
||||
"To setup reaction roles, first copy the link or message id of the message you want to "
|
||||
"add the roles to. Then run `{prefix}rroles link`, replacing `link` with the link you copied, "
|
||||
"and follow the prompts.\n"
|
||||
"See `{prefix}help rroles` for more information.".format(prefix=ctx.best_prefix)
|
||||
),
|
||||
colour=discord.Colour.orange()
|
||||
)
|
||||
await ctx.reply(embed=embed)
|
||||
return
|
||||
|
||||
# Extract first word, look for a subcommand
|
||||
splits = ctx.args.split(maxsplit=1)
|
||||
subcmd = splits[0].lower()
|
||||
|
||||
if subcmd in ('enable', 'disable', 'delete'):
|
||||
# Truncate arguments and extract target
|
||||
if len(splits) > 1:
|
||||
ctx.args = splits[1]
|
||||
target_chid, target_id = _parse_messageref(ctx)
|
||||
else:
|
||||
target_chid = None
|
||||
target_id = None
|
||||
ctx.args = ''
|
||||
|
||||
# Handle subcommand special cases
|
||||
if subcmd == 'enable':
|
||||
if ctx.args and not target_id:
|
||||
await ctx.error_reply(
|
||||
"Couldn't find the message to enable!\n"
|
||||
"**Usage:** `{}rroles enable [message link or id]`.".format(ctx.best_prefix)
|
||||
)
|
||||
elif not target_id:
|
||||
# Confirm enabling of all reaction messages
|
||||
await reaction_ask(
|
||||
ctx,
|
||||
"Are you sure you want to enable all reaction role messages in this server?",
|
||||
timeout_msg="Prompt timed out, no reaction roles enabled.",
|
||||
cancel_msg="User cancelled, no reaction roles enabled."
|
||||
)
|
||||
reaction_role_messages.update_where(
|
||||
{'enabled': True},
|
||||
guildid=ctx.guild.id
|
||||
)
|
||||
await ctx.embed_reply(
|
||||
"All reaction role messages have been enabled.",
|
||||
colour=discord.Colour.green(),
|
||||
)
|
||||
else:
|
||||
# Fetch the target
|
||||
target = ReactionRoleMessage.fetch(target_id)
|
||||
if target is None:
|
||||
await ctx.error_reply(
|
||||
"This message doesn't have any reaction roles!\n"
|
||||
"Run the command again without `enable` to assign reaction roles."
|
||||
)
|
||||
else:
|
||||
# We have a valid target
|
||||
if target.enabled:
|
||||
await ctx.error_reply(
|
||||
"This message is already enabled!"
|
||||
)
|
||||
else:
|
||||
target.enabled = True
|
||||
await ctx.embed_reply(
|
||||
"The message has been enabled!"
|
||||
)
|
||||
elif subcmd == 'disable':
|
||||
if ctx.args and not target_id:
|
||||
await ctx.error_reply(
|
||||
"Couldn't find the message to disable!\n"
|
||||
"**Usage:** `{}rroles disable [message link or id]`.".format(ctx.best_prefix)
|
||||
)
|
||||
elif not target_id:
|
||||
# Confirm disabling of all reaction messages
|
||||
await reaction_ask(
|
||||
ctx,
|
||||
"Are you sure you want to disable all reaction role messages in this server?",
|
||||
timeout_msg="Prompt timed out, no reaction roles disabled.",
|
||||
cancel_msg="User cancelled, no reaction roles disabled."
|
||||
)
|
||||
reaction_role_messages.update_where(
|
||||
{'enabled': False},
|
||||
guildid=ctx.guild.id
|
||||
)
|
||||
await ctx.embed_reply(
|
||||
"All reaction role messages have been disabled.",
|
||||
colour=discord.Colour.green(),
|
||||
)
|
||||
else:
|
||||
# Fetch the target
|
||||
target = ReactionRoleMessage.fetch(target_id)
|
||||
if target is None:
|
||||
await ctx.error_reply(
|
||||
"This message doesn't have any reaction roles! Nothing to disable."
|
||||
)
|
||||
else:
|
||||
# We have a valid target
|
||||
if not target.enabled:
|
||||
await ctx.error_reply(
|
||||
"This message is already disabled!"
|
||||
)
|
||||
else:
|
||||
target.enabled = False
|
||||
await ctx.embed_reply(
|
||||
"The message has been disabled!"
|
||||
)
|
||||
elif subcmd == 'delete':
|
||||
if ctx.args and not target_id:
|
||||
await ctx.error_reply(
|
||||
"Couldn't find the message to remove!\n"
|
||||
"**Usage:** `{}rroles remove [message link or id]`.".format(ctx.best_prefix)
|
||||
)
|
||||
elif not target_id:
|
||||
# Confirm disabling of all reaction messages
|
||||
await reaction_ask(
|
||||
ctx,
|
||||
"Are you sure you want to remove all reaction role messages in this server?",
|
||||
timeout_msg="Prompt timed out, no messages removed.",
|
||||
cancel_msg="User cancelled, no messages removed."
|
||||
)
|
||||
reaction_role_messages.delete_where(
|
||||
guildid=ctx.guild.id
|
||||
)
|
||||
await ctx.embed_reply(
|
||||
"All reaction role messages have been removed.",
|
||||
colour=discord.Colour.green(),
|
||||
)
|
||||
else:
|
||||
# Fetch the target
|
||||
target = ReactionRoleMessage.fetch(target_id)
|
||||
if target is None:
|
||||
await ctx.error_reply(
|
||||
"This message doesn't have any reaction roles! Nothing to remove."
|
||||
)
|
||||
else:
|
||||
# We have a valid target
|
||||
target.delete()
|
||||
await ctx.embed_reply(
|
||||
"The message has been removed and is no longer a reaction role message."
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Just extract target
|
||||
target_chid, target_id = _parse_messageref(ctx)
|
||||
|
||||
# Handle target parsing issue
|
||||
if target_id is None:
|
||||
return await ctx.error_reply(
|
||||
"Couldn't parse `{}` as a message id or message link!\n"
|
||||
"See `{}help rroles` for detailed usage information.".format(ctx.args.split()[0], ctx.best_prefix)
|
||||
)
|
||||
|
||||
# Get the associated ReactionRoleMessage, if it exists
|
||||
target = ReactionRoleMessage.fetch(target_id)
|
||||
|
||||
# Get the target message
|
||||
if target:
|
||||
message = await target.fetch_message()
|
||||
if not message:
|
||||
# TODO: Consider offering some sort of `move` option here.
|
||||
await ctx.error_reply(
|
||||
"This reaction role message no longer exists!\n"
|
||||
"Use `{}rroles delete {}` to remove it from the list.".format(ctx.best_prefix, target.messageid)
|
||||
)
|
||||
else:
|
||||
message = None
|
||||
if target_chid:
|
||||
channel = ctx.guild.get_channel(target_chid)
|
||||
if not channel:
|
||||
await ctx.error_reply(
|
||||
"The provided channel no longer exists!"
|
||||
)
|
||||
elif not isinstance(channel, discord.TextChannel):
|
||||
await ctx.error_reply(
|
||||
"The provided channel is not a text channel!"
|
||||
)
|
||||
else:
|
||||
message = await channel.fetch_message(target_id)
|
||||
if not message:
|
||||
await ctx.error_reply(
|
||||
"Couldn't find the specified message in {}!".format(channel.mention)
|
||||
)
|
||||
else:
|
||||
out_msg = await ctx.embed_reply("Searching for `{}`".format(target_id))
|
||||
message = await ctx.find_message(target_id)
|
||||
try:
|
||||
await out_msg.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
if not message:
|
||||
await ctx.error_reply(
|
||||
"Couldn't find the message `{}`!".format(target_id)
|
||||
)
|
||||
if not message:
|
||||
return
|
||||
|
||||
# Handle the `remove` flag specially
|
||||
# In particular, all other flags are ignored
|
||||
if flags['remove']:
|
||||
if not target:
|
||||
await ctx.error_reply(
|
||||
"The specified message has no reaction roles! Nothing to remove."
|
||||
)
|
||||
else:
|
||||
# Parse emojis and remove from target
|
||||
target_emojis = {reaction.emoji: reaction for reaction in target.reactions}
|
||||
emoji_name_map = {emoji.name.lower(): emoji for emoji in target_emojis}
|
||||
emoji_id_map = {emoji.id: emoji for emoji in target_emojis}
|
||||
|
||||
items = [item.strip() for item in flags['remove'].split(',')]
|
||||
to_remove = [] # List of reactions to remove
|
||||
for emoji_str in items:
|
||||
emoji = _parse_emoji(emoji_str, emoji_name_map, emoji_id_map)
|
||||
if emoji is None:
|
||||
return await ctx.error_reply(
|
||||
"Couldn't parse `{}` as an emoji! No reactions were removed.".format(emoji_str)
|
||||
)
|
||||
if emoji not in target_emojis:
|
||||
return await ctx.error_reply(
|
||||
"{} is not a reaction role for this message!".format(emoji)
|
||||
)
|
||||
to_remove.append(target_emojis[emoji])
|
||||
|
||||
# Delete reactions from data
|
||||
description = '\n'.join("{} <@&{}>".format(reaction.emoji, reaction.data.roleid) for reaction in to_remove)
|
||||
reaction_role_reactions.delete_where(reactionid=[reaction.reactionid for reaction in to_remove])
|
||||
target.refresh()
|
||||
|
||||
# Ack
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.green(),
|
||||
title="Reaction Roles deactivated",
|
||||
description=description
|
||||
)
|
||||
await ctx.reply(embed=embed)
|
||||
return
|
||||
|
||||
# Any remaining arguments should be emoji specifications with optional role
|
||||
# Parse these now
|
||||
given_emojis = {} # Map PartialEmoji -> Optional[Role]
|
||||
existing_emojis = set() # Set of existing reaction emoji identifiers
|
||||
|
||||
if ctx.args:
|
||||
# First build the list of custom emojis we can accept by name
|
||||
# We do this by reverse precedence, so the highest priority emojis are added last
|
||||
custom_emojis = []
|
||||
custom_emojis.extend(ctx.guild.emojis) # Custom emojis in the guild
|
||||
if target:
|
||||
custom_emojis.extend([r.emoji for r in target.reactions]) # Configured reaction roles on the target
|
||||
custom_emojis.extend([r.emoji for r in message.reactions if r.custom_emoji]) # Actual reactions on the message
|
||||
|
||||
# Filter out the built in emojis and those without a name
|
||||
custom_emojis = (emoji for emoji in custom_emojis if emoji.name and emoji.id)
|
||||
|
||||
# Build the maps to lookup provided custom emojis
|
||||
emoji_name_map = {emoji.name.lower(): emoji for emoji in custom_emojis}
|
||||
emoji_id_map = {emoji.id: emoji for emoji in custom_emojis}
|
||||
|
||||
# Now parse the provided emojis
|
||||
# Assume that all-unicode strings are built-in emojis
|
||||
# We can't assume much else unless we have a list of such emojis
|
||||
splits = (split.strip() for line in ctx.args.splitlines() for split in line.split(',') if split)
|
||||
splits = (split.split(maxsplit=1) for split in splits if split)
|
||||
arg_emoji_strings = {
|
||||
split[0]: split[1] if len(split) > 1 else None
|
||||
for split in splits
|
||||
} # emoji_str -> Optional[role_str]
|
||||
|
||||
arg_emoji_map = {}
|
||||
for emoji_str, role_str in arg_emoji_strings.items():
|
||||
emoji = _parse_emoji(emoji_str, emoji_name_map, emoji_id_map)
|
||||
if emoji is None:
|
||||
return await ctx.error_reply(
|
||||
"Couldn't parse `{}` as an emoji!".format(emoji_str)
|
||||
)
|
||||
else:
|
||||
arg_emoji_map[emoji] = role_str
|
||||
|
||||
# Final pass extracts roles
|
||||
# If any new emojis were provided, their roles should be specified, we enforce this during role parsing
|
||||
# First collect the existing emoji strings
|
||||
if target:
|
||||
for reaction in target.reactions:
|
||||
emoji_id = reaction.emoji.name if reaction.emoji.id is None else reaction.emoji.id
|
||||
existing_emojis.add(emoji_id)
|
||||
|
||||
# Now parse and assign the roles, building the final map
|
||||
for emoji, role_str in arg_emoji_map.items():
|
||||
emoji_id = emoji.name if emoji.id is None else emoji.id
|
||||
role = None
|
||||
if role_str:
|
||||
role = await ctx.find_role(role_str, create=True, interactive=True, allow_notfound=False)
|
||||
elif emoji_id not in existing_emojis:
|
||||
return await ctx.error_reply(
|
||||
"New emoji {} was given without an associated role!".format(emoji)
|
||||
)
|
||||
given_emojis[emoji] = role
|
||||
|
||||
# Next manage target creation or emoji editing, if required
|
||||
if target is None:
|
||||
# Reaction message creation wizard
|
||||
# Confirm that they want to create a new reaction role message.
|
||||
await reaction_ask(
|
||||
ctx,
|
||||
question="Do you want to set up new reaction roles for [this message]({})?".format(
|
||||
message.jump_url
|
||||
),
|
||||
timeout_msg="Prompt timed out, no reaction roles created.",
|
||||
cancel_msg="Reaction Role creation cancelled."
|
||||
)
|
||||
|
||||
# Continue with creation
|
||||
# Obtain emojis if not already provided
|
||||
if not given_emojis:
|
||||
# Prompt for the initial emojis
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
title="What reaction roles would you like to add?",
|
||||
description=(
|
||||
"Please now type the reaction roles you would like to add "
|
||||
"in the form `emoji role`, where `role` is given by partial name or id. For example:"
|
||||
"```{}```".format(example_str)
|
||||
)
|
||||
)
|
||||
out_msg = await ctx.reply(embed=embed)
|
||||
|
||||
# Wait for a response
|
||||
def check(msg):
|
||||
return msg.author == ctx.author and msg.channel == ctx.ch and msg.content
|
||||
|
||||
try:
|
||||
reply = await ctx.client.wait_for('message', check=check, timeout=300)
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
await out_msg.edit(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.red(),
|
||||
description="Prompt timed out, no reaction roles created."
|
||||
)
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
return
|
||||
|
||||
rolestrs = reply.content
|
||||
|
||||
try:
|
||||
await reply.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
# Attempt to parse the emojis
|
||||
# First build the list of custom emojis we can accept by name
|
||||
custom_emojis = []
|
||||
custom_emojis.extend(ctx.guild.emojis) # Custom emojis in the guild
|
||||
custom_emojis.extend(
|
||||
r.emoji for r in message.reactions if r.custom_emoji
|
||||
) # Actual reactions on the message
|
||||
|
||||
# Filter out the built in emojis and those without a name
|
||||
custom_emojis = (emoji for emoji in custom_emojis if emoji.name and emoji.id)
|
||||
|
||||
# Build the maps to lookup provided custom emojis
|
||||
emoji_name_map = {emoji.name.lower(): emoji for emoji in custom_emojis}
|
||||
emoji_id_map = {emoji.id: emoji for emoji in custom_emojis}
|
||||
|
||||
# Now parse the provided emojis
|
||||
# Assume that all-unicode strings are built-in emojis
|
||||
# We can't assume much else unless we have a list of such emojis
|
||||
splits = (split.strip() for line in rolestrs.splitlines() for split in line.split(',') if split)
|
||||
splits = (split.split(maxsplit=1) for split in splits if split)
|
||||
arg_emoji_strings = {
|
||||
split[0]: split[1] if len(split) > 1 else None
|
||||
for split in splits
|
||||
} # emoji_str -> Optional[role_str]
|
||||
|
||||
# Check all the emojis have roles associated
|
||||
for emoji_str, role_str in arg_emoji_strings.items():
|
||||
if role_str is None:
|
||||
return await ctx.error_reply(
|
||||
"No role provided for `{}`! Reaction role creation cancelled.".format(emoji_str)
|
||||
)
|
||||
|
||||
# Parse the provided roles and emojis
|
||||
for emoji_str, role_str in arg_emoji_strings.items():
|
||||
emoji = _parse_emoji(emoji_str, emoji_name_map, emoji_id_map)
|
||||
if emoji is None:
|
||||
return await ctx.error_reply(
|
||||
"Couldn't parse `{}` as an emoji!".format(emoji_str)
|
||||
)
|
||||
else:
|
||||
given_emojis[emoji] = await ctx.find_role(
|
||||
role_str,
|
||||
create=True,
|
||||
interactive=True,
|
||||
allow_notfound=False
|
||||
)
|
||||
|
||||
if len(given_emojis) > 20:
|
||||
return await ctx.error_reply("A maximum of 20 reactions are possible per message! Cancelling creation.")
|
||||
|
||||
# Create the ReactionRoleMessage
|
||||
target = ReactionRoleMessage.create(
|
||||
message.id,
|
||||
message.guild.id,
|
||||
message.channel.id
|
||||
)
|
||||
|
||||
# Insert the reaction data directly
|
||||
reaction_role_reactions.insert_many(
|
||||
*((message.id, role.id, emoji.name, emoji.id, emoji.animated) for emoji, role in given_emojis.items()),
|
||||
insert_keys=('messageid', 'roleid', 'emoji_name', 'emoji_id', 'emoji_animated')
|
||||
)
|
||||
|
||||
# Refresh the message to pick up the new reactions
|
||||
target.refresh()
|
||||
|
||||
# Add the reactions to the message, if possible
|
||||
existing_reactions = set(
|
||||
reaction.emoji if not reaction.custom_emoji else
|
||||
(reaction.emoji.name if reaction.emoji.id is None else reaction.emoji.id)
|
||||
for reaction in message.reactions
|
||||
)
|
||||
missing = [
|
||||
reaction.emoji for reaction in target.reactions
|
||||
if (reaction.emoji.name if reaction.emoji.id is None else reaction.emoji.id) not in existing_reactions
|
||||
]
|
||||
if not any(emoji.id not in set(cemoji.id for cemoji in ctx.guild.emojis) for emoji in missing if emoji.id):
|
||||
# We can add the missing emojis
|
||||
for emoji in missing:
|
||||
try:
|
||||
await message.add_reaction(emoji)
|
||||
except discord.HTTPException:
|
||||
break
|
||||
else:
|
||||
missing = []
|
||||
|
||||
# Ack the creation
|
||||
ack_msg = "Created `{}` new reaction roles on [this message]({})!".format(
|
||||
len(target.reactions),
|
||||
target.message_link
|
||||
)
|
||||
if missing:
|
||||
ack_msg += "\nPlease add the missing reactions to the message!"
|
||||
await ctx.embed_reply(
|
||||
ack_msg
|
||||
)
|
||||
elif given_emojis:
|
||||
# Update the target reactions
|
||||
# Create a map of the emojis that need to be added or updated
|
||||
needs_update = {
|
||||
emoji: role for emoji, role in given_emojis.items() if role
|
||||
}
|
||||
|
||||
# Fetch the existing target emojis to split the roles into inserts and updates
|
||||
target_emojis = {reaction.emoji: reaction for reaction in target.reactions}
|
||||
|
||||
# Handle the new roles
|
||||
insert_targets = {
|
||||
emoji: role for emoji, role in needs_update.items() if emoji not in target_emojis
|
||||
}
|
||||
if insert_targets:
|
||||
if len(insert_targets) + len(target_emojis) > 20:
|
||||
return await ctx.error_reply("Too many reactions! A maximum of 20 reactions are possible per message!")
|
||||
reaction_role_reactions.insert_many(
|
||||
*(
|
||||
(message.id, role.id, emoji.name, emoji.id, emoji.animated)
|
||||
for emoji, role in insert_targets.items()
|
||||
),
|
||||
insert_keys=('messageid', 'roleid', 'emoji_name', 'emoji_id', 'emoji_animated')
|
||||
)
|
||||
# Handle the updated roles
|
||||
update_targets = {
|
||||
target_emojis[emoji]: role for emoji, role in needs_update.items() if emoji in target_emojis
|
||||
}
|
||||
if update_targets:
|
||||
reaction_role_reactions.update_many(
|
||||
*((role.id, reaction.reactionid) for reaction, role in update_targets.items()),
|
||||
set_keys=('roleid',),
|
||||
where_keys=('reactionid',),
|
||||
)
|
||||
|
||||
# Finally, refresh to load the new reactions
|
||||
target.refresh()
|
||||
|
||||
# Now that the target is created/updated, all the provided emojis should be reactions
|
||||
given_reactions = []
|
||||
if given_emojis:
|
||||
# Make a map of the existing reactions
|
||||
existing_reactions = {
|
||||
reaction.emoji.name if reaction.emoji.id is None else reaction.emoji.id: reaction
|
||||
for reaction in target.reactions
|
||||
}
|
||||
given_reactions = [
|
||||
existing_reactions[emoji.name if emoji.id is None else emoji.id]
|
||||
for emoji in given_emojis
|
||||
]
|
||||
|
||||
# Handle message setting updates
|
||||
update_lines = [] # Setting update lines to display
|
||||
update_columns = {} # Message data columns to update
|
||||
for flag in _message_setting_flags:
|
||||
if flags[flag]:
|
||||
setting_class = _message_setting_flags[flag]
|
||||
try:
|
||||
setting = await setting_class.parse(target.messageid, ctx, flags[flag])
|
||||
except UserInputError as e:
|
||||
return await ctx.error_reply(
|
||||
"{} {}\nNo settings were modified.".format(cross, e.msg),
|
||||
title="Couldn't save settings!"
|
||||
)
|
||||
else:
|
||||
update_lines.append(
|
||||
"{} {}".format(tick, setting.success_response)
|
||||
)
|
||||
update_columns[setting._data_column] = setting.data
|
||||
if update_columns:
|
||||
# First write the data
|
||||
reaction_role_messages.update_where(
|
||||
update_columns,
|
||||
messageid=target.messageid
|
||||
)
|
||||
# Then ack the setting update
|
||||
if len(update_lines) > 1:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.green(),
|
||||
title="Reaction Role message settings updated!",
|
||||
description='\n'.join(update_lines)
|
||||
)
|
||||
else:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.green(),
|
||||
description=update_lines[0]
|
||||
)
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
# Handle reaction setting updates
|
||||
update_lines = [] # Setting update lines to display
|
||||
update_columns = {} # Message data columns to update, for all given reactions
|
||||
reactions = given_reactions or target.reactions
|
||||
for flag in _reaction_setting_flags:
|
||||
for reaction in reactions:
|
||||
if flags[flag]:
|
||||
setting_class = _reaction_setting_flags[flag]
|
||||
try:
|
||||
setting = await setting_class.parse(reaction.reactionid, ctx, flags[flag])
|
||||
except UserInputError as e:
|
||||
return await ctx.error_reply(
|
||||
"{} {}\nNo reaction roles were modified.".format(cross, e.msg),
|
||||
title="Couldn't save reaction role settings!",
|
||||
)
|
||||
else:
|
||||
update_lines.append(
|
||||
setting.success_response.format(reaction=reaction)
|
||||
)
|
||||
update_columns[setting._data_column] = setting.data
|
||||
if update_columns:
|
||||
# First write the data
|
||||
reaction_role_reactions.update_where(
|
||||
update_columns,
|
||||
reactionid=[reaction.reactionid for reaction in reactions]
|
||||
)
|
||||
# Then ack the setting update
|
||||
if len(update_lines) > 1:
|
||||
blocks = ['\n'.join(update_lines[i:i+20]) for i in range(0, len(update_lines), 20)]
|
||||
embeds = [
|
||||
discord.Embed(
|
||||
colour=discord.Colour.green(),
|
||||
title="Reaction Role settings updated!",
|
||||
description=block
|
||||
) for block in blocks
|
||||
]
|
||||
await ctx.pager(embeds)
|
||||
else:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.green(),
|
||||
description=update_lines[0]
|
||||
)
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
# Show the reaction role message summary
|
||||
# Build the reaction fields
|
||||
reaction_fields = [] # List of tuples (name, value)
|
||||
for reaction in target.reactions:
|
||||
reaction_fields.append(
|
||||
(
|
||||
"{} {}".format(reaction.emoji.name, reaction.emoji if reaction.emoji.id else ''),
|
||||
"<@&{}>\n{}".format(reaction.data.roleid, reaction.settings.tabulated())
|
||||
)
|
||||
)
|
||||
|
||||
# Build the final setting pages
|
||||
description = (
|
||||
"{settings_table}\n"
|
||||
"To update a message setting: `{prefix}rroles messageid --setting value`\n"
|
||||
"To update an emoji setting: `{prefix}rroles messageid emoji --setting value`\n"
|
||||
"See examples and more usage information with `{prefix}help rroles`.\n"
|
||||
"**(!) Replace the `setting` with one of the settings on this page.**\n"
|
||||
).format(
|
||||
prefix=ctx.best_prefix,
|
||||
settings_table=target.settings.tabulated()
|
||||
)
|
||||
|
||||
field_blocks = [reaction_fields[i:i+6] for i in range(0, len(reaction_fields), 6)]
|
||||
page_count = len(field_blocks)
|
||||
embeds = []
|
||||
for i, block in enumerate(field_blocks):
|
||||
title = "Reaction role settings for message id `{}`".format(target.messageid)
|
||||
embed = discord.Embed(
|
||||
title=title,
|
||||
description=description
|
||||
).set_author(
|
||||
name="Click to jump to message",
|
||||
url=target.message_link
|
||||
)
|
||||
for name, value in block:
|
||||
embed.add_field(name=name, value=value)
|
||||
if page_count > 1:
|
||||
embed.set_footer(text="Page {} of {}".format(i+1, page_count))
|
||||
embeds.append(embed)
|
||||
|
||||
# Finally, send the reaction role information
|
||||
await ctx.pager(embeds)
|
||||
@@ -0,0 +1,22 @@
|
||||
from data import Table, RowTable
|
||||
|
||||
|
||||
reaction_role_messages = RowTable(
|
||||
'reaction_role_messages',
|
||||
('messageid', 'guildid', 'channelid',
|
||||
'enabled',
|
||||
'required_role', 'allow_deselction',
|
||||
'max_obtainable', 'allow_refunds',
|
||||
'event_log'),
|
||||
'messageid'
|
||||
)
|
||||
|
||||
|
||||
reaction_role_reactions = RowTable(
|
||||
'reaction_role_reactions',
|
||||
('reactionid', 'messageid', 'roleid', 'emoji_name', 'emoji_id', 'emoji_animated', 'price', 'timeout'),
|
||||
'reactionid'
|
||||
)
|
||||
|
||||
|
||||
reaction_role_expiring = Table('reaction_role_expiring')
|
||||
172
src/modules/pending-rewrite/guild_admin/reaction_roles/expiry.py
Normal file
172
src/modules/pending-rewrite/guild_admin/reaction_roles/expiry.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import logging
|
||||
import traceback
|
||||
import asyncio
|
||||
import discord
|
||||
|
||||
from meta import client
|
||||
from utils.lib import utc_now
|
||||
from settings import GuildSettings
|
||||
|
||||
from .module import module
|
||||
from .data import reaction_role_expiring
|
||||
|
||||
_expiring = {}
|
||||
_wakeup_event = asyncio.Event()
|
||||
|
||||
|
||||
# TODO: More efficient data structure for min optimisation, e.g. pre-sorted with bisection insert
|
||||
|
||||
|
||||
# Public expiry interface
|
||||
def schedule_expiry(guildid, userid, roleid, expiry, reactionid=None):
|
||||
"""
|
||||
Schedule expiry of the given role for the given member at the given time.
|
||||
This will also cancel any existing expiry for this member, role pair.
|
||||
"""
|
||||
reaction_role_expiring.delete_where(
|
||||
guildid=guildid,
|
||||
userid=userid,
|
||||
roleid=roleid,
|
||||
)
|
||||
reaction_role_expiring.insert(
|
||||
guildid=guildid,
|
||||
userid=userid,
|
||||
roleid=roleid,
|
||||
expiry=expiry,
|
||||
reactionid=reactionid
|
||||
)
|
||||
key = (guildid, userid, roleid)
|
||||
_expiring[key] = expiry.timestamp()
|
||||
_wakeup_event.set()
|
||||
|
||||
|
||||
def cancel_expiry(*key):
|
||||
"""
|
||||
Cancel expiry for the given member and role, if it exists.
|
||||
"""
|
||||
guildid, userid, roleid = key
|
||||
reaction_role_expiring.delete_where(
|
||||
guildid=guildid,
|
||||
userid=userid,
|
||||
roleid=roleid,
|
||||
)
|
||||
if _expiring.pop(key, None) is not None:
|
||||
# Wakeup the expiry tracker for recalculation
|
||||
_wakeup_event.set()
|
||||
|
||||
|
||||
def _next():
|
||||
"""
|
||||
Calculate the next member, role pair to expire.
|
||||
"""
|
||||
if _expiring:
|
||||
key, _ = min(_expiring.items(), key=lambda pair: pair[1])
|
||||
return key
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def _expire(key):
|
||||
"""
|
||||
Execute reaction role expiry for the given member and role.
|
||||
This removes the role and logs the removal if applicable.
|
||||
If the user is no longer in the guild, it removes the role from the persistent roles instead.
|
||||
"""
|
||||
guildid, userid, roleid = key
|
||||
guild = client.get_guild(guildid)
|
||||
if guild:
|
||||
role = guild.get_role(roleid)
|
||||
if role:
|
||||
member = guild.get_member(userid)
|
||||
if member:
|
||||
log = GuildSettings(guildid).event_log.log
|
||||
if role in member.roles:
|
||||
# Remove role from member, and log if applicable
|
||||
try:
|
||||
await member.remove_roles(
|
||||
role,
|
||||
atomic=True,
|
||||
reason="Expiring temporary reaction role."
|
||||
)
|
||||
except discord.HTTPException:
|
||||
log(
|
||||
"Failed to remove expired reaction role {} from {}.".format(
|
||||
role.mention,
|
||||
member.mention
|
||||
),
|
||||
colour=discord.Colour.red(),
|
||||
title="Could not remove expired Reaction Role!"
|
||||
)
|
||||
else:
|
||||
log(
|
||||
"Removing expired reaction role {} from {}.".format(
|
||||
role.mention,
|
||||
member.mention
|
||||
),
|
||||
title="Reaction Role expired!"
|
||||
)
|
||||
else:
|
||||
# Remove role from stored persistent roles, if existent
|
||||
client.data.past_member_roles.delete_where(
|
||||
guildid=guildid,
|
||||
userid=userid,
|
||||
roleid=roleid
|
||||
)
|
||||
reaction_role_expiring.delete_where(
|
||||
guildid=guildid,
|
||||
userid=userid,
|
||||
roleid=roleid
|
||||
)
|
||||
|
||||
|
||||
async def _expiry_tracker(client):
|
||||
"""
|
||||
Track and launch role expiry.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
key = _next()
|
||||
diff = _expiring[key] - utc_now().timestamp() if key else None
|
||||
await asyncio.wait_for(_wakeup_event.wait(), timeout=diff)
|
||||
except asyncio.TimeoutError:
|
||||
# Timeout means next doesn't exist or is ready to expire
|
||||
if key and key in _expiring and _expiring[key] <= utc_now().timestamp() + 1:
|
||||
_expiring.pop(key)
|
||||
asyncio.create_task(_expire(key))
|
||||
except Exception:
|
||||
# This should be impossible, but catch and log anyway
|
||||
client.log(
|
||||
"Exception occurred while tracking reaction role expiry. Exception traceback follows.\n{}".format(
|
||||
traceback.format_exc()
|
||||
),
|
||||
context="REACTION_ROLE_EXPIRY",
|
||||
level=logging.ERROR
|
||||
)
|
||||
else:
|
||||
# Wakeup event means that we should recalculate next
|
||||
_wakeup_event.clear()
|
||||
|
||||
|
||||
@module.launch_task
|
||||
async def launch_expiry_tracker(client):
|
||||
"""
|
||||
Launch the role expiry tracker.
|
||||
"""
|
||||
asyncio.create_task(_expiry_tracker(client))
|
||||
client.log("Reaction role expiry tracker launched.", context="REACTION_ROLE_EXPIRY")
|
||||
|
||||
|
||||
@module.init_task
|
||||
def load_expiring_roles(client):
|
||||
"""
|
||||
Initialise the expiring reaction role map, and attach it to the client.
|
||||
"""
|
||||
rows = reaction_role_expiring.select_where()
|
||||
_expiring.clear()
|
||||
_expiring.update({(row['guildid'], row['userid'], row['roleid']): row['expiry'].timestamp() for row in rows})
|
||||
client.objects['expiring_reaction_roles'] = _expiring
|
||||
if _expiring:
|
||||
client.log(
|
||||
"Loaded {} expiring reaction roles.".format(len(_expiring)),
|
||||
context="REACTION_ROLE_EXPIRY"
|
||||
)
|
||||
@@ -0,0 +1,3 @@
|
||||
from LionModule import LionModule
|
||||
|
||||
module = LionModule("Reaction_Roles")
|
||||
@@ -0,0 +1,257 @@
|
||||
from utils.lib import DotDict
|
||||
from wards import guild_admin
|
||||
from settings import ObjectSettings, ColumnData, Setting
|
||||
import settings.setting_types as setting_types
|
||||
|
||||
from .data import reaction_role_messages, reaction_role_reactions
|
||||
|
||||
|
||||
class RoleMessageSettings(ObjectSettings):
|
||||
settings = DotDict()
|
||||
|
||||
|
||||
class RoleMessageSetting(ColumnData, Setting):
|
||||
_table_interface = reaction_role_messages
|
||||
_id_column = 'messageid'
|
||||
_create_row = False
|
||||
|
||||
write_ward = guild_admin
|
||||
|
||||
|
||||
@RoleMessageSettings.attach_setting
|
||||
class required_role(setting_types.Role, RoleMessageSetting):
|
||||
attr_name = 'required_role'
|
||||
_data_column = 'required_role'
|
||||
|
||||
display_name = "required_role"
|
||||
desc = "Role required to use the reaction roles."
|
||||
|
||||
long_desc = (
|
||||
"Members will be required to have the specified role to use the reactions on this message."
|
||||
)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "Members need {} to use these reaction roles.".format(self.formatted)
|
||||
else:
|
||||
return "All members can now use these reaction roles."
|
||||
|
||||
@classmethod
|
||||
def _get_guildid(cls, id: int, **kwargs):
|
||||
return reaction_role_messages.fetch(id).guildid
|
||||
|
||||
|
||||
@RoleMessageSettings.attach_setting
|
||||
class removable(setting_types.Boolean, RoleMessageSetting):
|
||||
attr_name = 'removable'
|
||||
_data_column = 'removable'
|
||||
|
||||
display_name = "removable"
|
||||
desc = "Whether the role is removable by deselecting the reaction."
|
||||
|
||||
long_desc = (
|
||||
"If enabled, the role will be removed when the reaction is deselected."
|
||||
)
|
||||
|
||||
_default = True
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "Members will be able to remove roles by unreacting."
|
||||
else:
|
||||
return "Members will not be able to remove the reaction roles."
|
||||
|
||||
|
||||
@RoleMessageSettings.attach_setting
|
||||
class maximum(setting_types.Integer, RoleMessageSetting):
|
||||
attr_name = 'maximum'
|
||||
_data_column = 'maximum'
|
||||
|
||||
display_name = "maximum"
|
||||
desc = "The maximum number of roles a member can get from this message."
|
||||
|
||||
long_desc = (
|
||||
"The maximum number of roles that a member can get from this message. "
|
||||
"They will be notified by DM if they attempt to add more.\n"
|
||||
"The `removable` setting should generally be enabled with this setting."
|
||||
)
|
||||
|
||||
accepts = "An integer number of roles, or `None` to remove the maximum."
|
||||
|
||||
_min = 0
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, id, data, **kwargs):
|
||||
if data is None:
|
||||
return "No maximum!"
|
||||
else:
|
||||
return "`{}`".format(data)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "Members can get a maximum of `{}` roles from this message.".format(self.value)
|
||||
else:
|
||||
return "Members can now get all the roles from this mesage."
|
||||
|
||||
|
||||
@RoleMessageSettings.attach_setting
|
||||
class refunds(setting_types.Boolean, RoleMessageSetting):
|
||||
attr_name = 'refunds'
|
||||
_data_column = 'refunds'
|
||||
|
||||
display_name = "refunds"
|
||||
desc = "Whether a user will be refunded when they deselect a role."
|
||||
|
||||
long_desc = (
|
||||
"Whether to give the user a refund when they deselect a role by reaction. "
|
||||
"This has no effect if `removable` is not enabled, or if the role removed has no cost."
|
||||
)
|
||||
|
||||
_default = True
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "Members will get a refund when they remove a role."
|
||||
else:
|
||||
return "Members will not get a refund when they remove a role."
|
||||
|
||||
|
||||
@RoleMessageSettings.attach_setting
|
||||
class default_price(setting_types.Integer, RoleMessageSetting):
|
||||
attr_name = 'default_price'
|
||||
_data_column = 'default_price'
|
||||
|
||||
display_name = "default_price"
|
||||
desc = "Default price of reaction roles on this message."
|
||||
|
||||
long_desc = (
|
||||
"Reaction roles on this message will have this cost if they do not have an individual price set."
|
||||
)
|
||||
|
||||
accepts = "An integer number of coins. Use `0` or `None` to make roles free by default."
|
||||
|
||||
_default = 0
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, id, data, **kwargs):
|
||||
if not data:
|
||||
return "Free"
|
||||
else:
|
||||
return "`{}` coins".format(data)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "Reaction roles on this message will cost `{}` coins by default.".format(self.value)
|
||||
else:
|
||||
return "Reaction roles on this message will be free by default."
|
||||
|
||||
|
||||
@RoleMessageSettings.attach_setting
|
||||
class log(setting_types.Boolean, RoleMessageSetting):
|
||||
attr_name = 'log'
|
||||
_data_column = 'event_log'
|
||||
|
||||
display_name = "log"
|
||||
desc = "Whether to log reaction role usage in the event log."
|
||||
|
||||
long_desc = (
|
||||
"When enabled, roles added or removed with reactions will be logged in the configured event log."
|
||||
)
|
||||
|
||||
_default = True
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "Role updates will now be logged."
|
||||
else:
|
||||
return "Role updates will not be logged."
|
||||
|
||||
|
||||
class ReactionSettings(ObjectSettings):
|
||||
settings = DotDict()
|
||||
|
||||
|
||||
class ReactionSetting(ColumnData, Setting):
|
||||
_table_interface = reaction_role_reactions
|
||||
_id_column = 'reactionid'
|
||||
_create_row = False
|
||||
|
||||
write_ward = guild_admin
|
||||
|
||||
|
||||
@ReactionSettings.attach_setting
|
||||
class price(setting_types.Integer, ReactionSetting):
|
||||
attr_name = 'price'
|
||||
_data_column = 'price'
|
||||
|
||||
display_name = "price"
|
||||
desc = "Price of this reaction role (may be negative)."
|
||||
|
||||
long_desc = (
|
||||
"The number of coins that will be deducted from the user when this reaction is used.\n"
|
||||
"The number may be negative, in order to give a reward when the member choses the reaction."
|
||||
)
|
||||
|
||||
accepts = "An integer number of coins. Use `0` to make the role free, or `None` to use the message default."
|
||||
_max = 2 ** 20
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
"""
|
||||
The default price is given by the ReactionMessage price setting.
|
||||
"""
|
||||
return default_price.get(self._table_interface.fetch(self.id).messageid).value
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, id, data, **kwargs):
|
||||
if not data:
|
||||
return "Free"
|
||||
else:
|
||||
return "`{}` coins".format(data)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value is not None:
|
||||
return "{{reaction.emoji}} {{reaction.role.mention}} now costs `{}` coins.".format(self.value)
|
||||
else:
|
||||
return "{reaction.emoji} {reaction.role.mention} is now free."
|
||||
|
||||
|
||||
@ReactionSettings.attach_setting
|
||||
class duration(setting_types.Duration, ReactionSetting):
|
||||
attr_name = 'duration'
|
||||
_data_column = 'timeout'
|
||||
|
||||
display_name = "duration"
|
||||
desc = "How long this reaction role will last."
|
||||
|
||||
long_desc = (
|
||||
"If set, the reaction role will be removed after the configured duration. "
|
||||
"Note that this does not affect existing members with the role, or existing expiries."
|
||||
)
|
||||
|
||||
_default_multiplier = 3600
|
||||
_show_days = True
|
||||
_min = 600
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, id, data, **kwargs):
|
||||
if data is None:
|
||||
return "Permanent"
|
||||
else:
|
||||
return super()._format_data(id, data, **kwargs)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value is not None:
|
||||
return "{{reaction.emoji}} {{reaction.role.mention}} will expire `{}` after selection.".format(
|
||||
self.formatted
|
||||
)
|
||||
else:
|
||||
return "{reaction.emoji} {reaction.role.mention} will not expire."
|
||||
@@ -0,0 +1,590 @@
|
||||
import asyncio
|
||||
from codecs import ignore_errors
|
||||
import logging
|
||||
import traceback
|
||||
import datetime
|
||||
from collections import defaultdict
|
||||
from typing import List, Mapping, Optional
|
||||
from cachetools import LFUCache
|
||||
|
||||
import discord
|
||||
from discord import PartialEmoji
|
||||
|
||||
from meta import client
|
||||
from core import Lion
|
||||
from data import Row
|
||||
from data.conditions import THIS_SHARD
|
||||
from utils.lib import utc_now
|
||||
from settings import GuildSettings
|
||||
|
||||
from ..module import module
|
||||
from .data import reaction_role_messages, reaction_role_reactions
|
||||
from .settings import RoleMessageSettings, ReactionSettings
|
||||
from .expiry import schedule_expiry, cancel_expiry
|
||||
|
||||
|
||||
class ReactionRoleReaction:
|
||||
"""
|
||||
Light data class representing a reaction role reaction.
|
||||
"""
|
||||
__slots__ = ('reactionid', '_emoji', '_message', '_role')
|
||||
|
||||
def __init__(self, reactionid, message=None, **kwargs):
|
||||
self.reactionid = reactionid
|
||||
self._message: ReactionRoleMessage = None
|
||||
self._role = None
|
||||
self._emoji = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, messageid, roleid, emoji: PartialEmoji, message=None, **kwargs) -> 'ReactionRoleReaction':
|
||||
"""
|
||||
Create a new ReactionRoleReaction with the provided attributes.
|
||||
`emoji` sould be provided as a PartialEmoji.
|
||||
`kwargs` are passed transparently to the `insert` method.
|
||||
"""
|
||||
row = reaction_role_reactions.create_row(
|
||||
messageid=messageid,
|
||||
roleid=roleid,
|
||||
emoji_name=emoji.name,
|
||||
emoji_id=emoji.id,
|
||||
emoji_animated=emoji.animated,
|
||||
**kwargs
|
||||
)
|
||||
return cls(row.reactionid, message=message)
|
||||
|
||||
@property
|
||||
def emoji(self) -> PartialEmoji:
|
||||
if self._emoji is None:
|
||||
data = self.data
|
||||
self._emoji = PartialEmoji(
|
||||
name=data.emoji_name,
|
||||
animated=data.emoji_animated,
|
||||
id=data.emoji_id,
|
||||
)
|
||||
return self._emoji
|
||||
|
||||
@property
|
||||
def data(self) -> Row:
|
||||
return reaction_role_reactions.fetch(self.reactionid)
|
||||
|
||||
@property
|
||||
def settings(self) -> ReactionSettings:
|
||||
return ReactionSettings(self.reactionid)
|
||||
|
||||
@property
|
||||
def reaction_message(self):
|
||||
if self._message is None:
|
||||
self._message = ReactionRoleMessage.fetch(self.data.messageid)
|
||||
return self._message
|
||||
|
||||
@property
|
||||
def role(self):
|
||||
if self._role is None:
|
||||
guild = self.reaction_message.guild
|
||||
if guild:
|
||||
self._role = guild.get_role(self.data.roleid)
|
||||
return self._role
|
||||
|
||||
|
||||
class ReactionRoleMessage:
|
||||
"""
|
||||
Light data class representing a reaction role message.
|
||||
Primarily acts as an interface to the corresponding Settings.
|
||||
"""
|
||||
__slots__ = ('messageid', '_message')
|
||||
|
||||
# Full live messageid cache for this client. Should always be up to date.
|
||||
_messages: Mapping[int, 'ReactionRoleMessage'] = {} # messageid -> associated Reaction message
|
||||
|
||||
# Reaction cache for the live messages. Least frequently used, will be fetched on demand.
|
||||
_reactions: Mapping[int, List[ReactionRoleReaction]] = LFUCache(1000) # messageid -> List of Reactions
|
||||
|
||||
# User-keyed locks so we only handle one reaction per user at a time
|
||||
_locks: Mapping[int, asyncio.Lock] = defaultdict(asyncio.Lock) # userid -> Lock
|
||||
|
||||
def __init__(self, messageid):
|
||||
self.messageid = messageid
|
||||
self._message = None
|
||||
|
||||
@classmethod
|
||||
def fetch(cls, messageid) -> 'ReactionRoleMessage':
|
||||
"""
|
||||
Fetch the ReactionRoleMessage for the provided messageid.
|
||||
Returns None if the messageid is not registered.
|
||||
"""
|
||||
# Since the cache is assumed to be always up to date, just pass to fetch-from-cache.
|
||||
return cls._messages.get(messageid, None)
|
||||
|
||||
@classmethod
|
||||
def create(cls, messageid, guildid, channelid, **kwargs) -> 'ReactionRoleMessage':
|
||||
"""
|
||||
Create a ReactionRoleMessage with the given `messageid`.
|
||||
Other `kwargs` are passed transparently to `insert`.
|
||||
"""
|
||||
# Insert the data
|
||||
reaction_role_messages.create_row(
|
||||
messageid=messageid,
|
||||
guildid=guildid,
|
||||
channelid=channelid,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Create the ReactionRoleMessage
|
||||
rmsg = cls(messageid)
|
||||
|
||||
# Add to the global cache
|
||||
cls._messages[messageid] = rmsg
|
||||
|
||||
# Return the constructed ReactionRoleMessage
|
||||
return rmsg
|
||||
|
||||
def delete(self):
|
||||
"""
|
||||
Delete this ReactionRoleMessage.
|
||||
"""
|
||||
# Remove message from cache
|
||||
self._messages.pop(self.messageid, None)
|
||||
|
||||
# Remove reactions from cache
|
||||
reactionids = [reaction.reactionid for reaction in self.reactions]
|
||||
[self._reactions.pop(reactionid, None) for reactionid in reactionids]
|
||||
|
||||
# Remove message from data
|
||||
reaction_role_messages.delete_where(messageid=self.messageid)
|
||||
|
||||
@property
|
||||
def data(self) -> Row:
|
||||
"""
|
||||
Data row associated with this Message.
|
||||
Passes directly to the RowTable cache.
|
||||
Should not generally be used directly, use the settings interface instead.
|
||||
"""
|
||||
return reaction_role_messages.fetch(self.messageid)
|
||||
|
||||
@property
|
||||
def settings(self):
|
||||
"""
|
||||
RoleMessageSettings associated to this Message.
|
||||
"""
|
||||
return RoleMessageSettings(self.messageid)
|
||||
|
||||
def refresh(self):
|
||||
"""
|
||||
Refresh the reaction cache for this message.
|
||||
Returns the generated `ReactionRoleReaction`s for convenience.
|
||||
"""
|
||||
# Fetch reactions and pre-populate reaction cache
|
||||
rows = reaction_role_reactions.fetch_rows_where(messageid=self.messageid, _extra="ORDER BY reactionid ASC")
|
||||
reactions = [ReactionRoleReaction(row.reactionid) for row in rows]
|
||||
self._reactions[self.messageid] = reactions
|
||||
return reactions
|
||||
|
||||
@property
|
||||
def reactions(self) -> List[ReactionRoleReaction]:
|
||||
"""
|
||||
Returns the list of active reactions for this message, as `ReactionRoleReaction`s.
|
||||
Lazily fetches the reactions from data if they have not been loaded.
|
||||
"""
|
||||
reactions = self._reactions.get(self.messageid, None)
|
||||
if reactions is None:
|
||||
reactions = self.refresh()
|
||||
return reactions
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""
|
||||
Whether this Message is enabled.
|
||||
Passes directly to data for efficiency.
|
||||
"""
|
||||
return self.data.enabled
|
||||
|
||||
@enabled.setter
|
||||
def enabled(self, value: bool):
|
||||
self.data.enabled = value
|
||||
|
||||
# Discord properties
|
||||
@property
|
||||
def guild(self) -> discord.Guild:
|
||||
return client.get_guild(self.data.guildid)
|
||||
|
||||
@property
|
||||
def channel(self) -> discord.TextChannel:
|
||||
return client.get_channel(self.data.channelid)
|
||||
|
||||
async def fetch_message(self) -> discord.Message:
|
||||
if self._message:
|
||||
return self._message
|
||||
|
||||
channel = self.channel
|
||||
if channel:
|
||||
try:
|
||||
self._message = await channel.fetch_message(self.messageid)
|
||||
return self._message
|
||||
except discord.NotFound:
|
||||
# The message no longer exists
|
||||
# TODO: Cache and data cleanup? Or allow moving after death?
|
||||
pass
|
||||
|
||||
@property
|
||||
def message(self) -> Optional[discord.Message]:
|
||||
return self._message
|
||||
|
||||
@property
|
||||
def message_link(self) -> str:
|
||||
"""
|
||||
Jump link tho the reaction message.
|
||||
"""
|
||||
return 'https://discord.com/channels/{}/{}/{}'.format(
|
||||
self.data.guildid,
|
||||
self.data.channelid,
|
||||
self.messageid
|
||||
)
|
||||
|
||||
# Event handlers
|
||||
async def process_raw_reaction_add(self, payload):
|
||||
"""
|
||||
Process a general reaction add payload.
|
||||
"""
|
||||
event_log = GuildSettings(self.guild.id).event_log
|
||||
async with self._locks[payload.user_id]:
|
||||
reaction = next((reaction for reaction in self.reactions if reaction.emoji == payload.emoji), None)
|
||||
if reaction:
|
||||
# User pressed a live reaction. Process!
|
||||
member = payload.member
|
||||
lion = Lion.fetch(member.guild.id, member.id)
|
||||
role = reaction.role
|
||||
if reaction.role and (role not in member.roles):
|
||||
# Required role check, make sure the user has the required role, if set.
|
||||
required_role = self.settings.required_role.value
|
||||
if required_role and required_role not in member.roles:
|
||||
# Silently remove their reaction
|
||||
try:
|
||||
message = await self.fetch_message()
|
||||
await message.remove_reaction(
|
||||
payload.emoji,
|
||||
member
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
return
|
||||
|
||||
# Maximum check, check whether the user already has too many roles from this message.
|
||||
maximum = self.settings.maximum.value
|
||||
if maximum is not None:
|
||||
# Fetch the number of applicable roles the user has
|
||||
roleids = set(reaction.data.roleid for reaction in self.reactions)
|
||||
member_roleids = set(role.id for role in member.roles)
|
||||
if len(roleids.intersection(member_roleids)) >= maximum:
|
||||
# Notify the user
|
||||
embed = discord.Embed(
|
||||
title="Maximum group roles reached!",
|
||||
description=(
|
||||
"Couldn't give you **{}**, "
|
||||
"because you already have `{}` roles from this group!".format(
|
||||
role.name,
|
||||
maximum
|
||||
)
|
||||
)
|
||||
)
|
||||
# Silently try to notify the user
|
||||
try:
|
||||
await member.send(embed=embed)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
# Silently remove the reaction
|
||||
try:
|
||||
message = await self.fetch_message()
|
||||
await message.remove_reaction(
|
||||
payload.emoji,
|
||||
member
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
return
|
||||
|
||||
# Economy hook, check whether the user can pay for the role.
|
||||
price = reaction.settings.price.value
|
||||
if price and price > lion.coins:
|
||||
# They can't pay!
|
||||
# Build the can't pay embed
|
||||
embed = discord.Embed(
|
||||
title="Insufficient funds!",
|
||||
description="Sorry, **{}** costs `{}` coins, but you only have `{}`.".format(
|
||||
role.name,
|
||||
price,
|
||||
lion.coins
|
||||
),
|
||||
colour=discord.Colour.red()
|
||||
).set_footer(
|
||||
icon_url=self.guild.icon_url,
|
||||
text=self.guild.name
|
||||
).add_field(
|
||||
name="Jump Back",
|
||||
value="[Click here]({})".format(self.message_link)
|
||||
)
|
||||
# Try to send them the embed, ignore errors
|
||||
try:
|
||||
await member.send(
|
||||
embed=embed
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
# Remove their reaction, ignore errors
|
||||
try:
|
||||
message = await self.fetch_message()
|
||||
await message.remove_reaction(
|
||||
payload.emoji,
|
||||
member
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
return
|
||||
|
||||
# Add the role
|
||||
try:
|
||||
await member.add_roles(
|
||||
role,
|
||||
atomic=True,
|
||||
reason="Adding reaction role."
|
||||
)
|
||||
except discord.Forbidden:
|
||||
event_log.log(
|
||||
"Insufficient permissions to give {} the [reaction role]({}) {}".format(
|
||||
member.mention,
|
||||
self.message_link,
|
||||
role.mention,
|
||||
),
|
||||
title="Failed to add reaction role",
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
except discord.HTTPException:
|
||||
event_log.log(
|
||||
"Something went wrong while adding the [reaction role]({}) "
|
||||
"{} to {}.".format(
|
||||
self.message_link,
|
||||
role.mention,
|
||||
member.mention
|
||||
),
|
||||
title="Failed to add reaction role",
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
client.log(
|
||||
"Unexpected HTTPException encountered while adding '{}' (rid:{}) to "
|
||||
"user '{}' (uid:{}) in guild '{}' (gid:{}).\n{}".format(
|
||||
role.name,
|
||||
role.id,
|
||||
member,
|
||||
member.id,
|
||||
member.guild.name,
|
||||
member.guild.id,
|
||||
traceback.format_exc()
|
||||
),
|
||||
context="REACTION_ROLE_ADD",
|
||||
level=logging.WARNING
|
||||
)
|
||||
else:
|
||||
# Charge the user and notify them, if the price is set
|
||||
if price:
|
||||
lion.addCoins(-price)
|
||||
# Notify the user of their purchase
|
||||
embed = discord.Embed(
|
||||
title="Purchase successful!",
|
||||
description="You have purchased **{}** for `{}` coins!".format(
|
||||
role.name,
|
||||
price
|
||||
),
|
||||
colour=discord.Colour.green()
|
||||
).set_footer(
|
||||
icon_url=self.guild.icon_url,
|
||||
text=self.guild.name
|
||||
).add_field(
|
||||
name="Jump Back",
|
||||
value="[Click Here]({})".format(self.message_link)
|
||||
)
|
||||
try:
|
||||
await member.send(embed=embed)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
# Schedule the expiry, if required
|
||||
duration = reaction.settings.duration.value
|
||||
if duration:
|
||||
expiry = utc_now() + datetime.timedelta(seconds=duration)
|
||||
schedule_expiry(self.guild.id, member.id, role.id, expiry, reaction.reactionid)
|
||||
else:
|
||||
expiry = None
|
||||
|
||||
# Log the role modification if required
|
||||
if self.settings.log.value:
|
||||
event_log.log(
|
||||
"Added [reaction role]({}) {} "
|
||||
"to {}{}.{}".format(
|
||||
self.message_link,
|
||||
role.mention,
|
||||
member.mention,
|
||||
" for `{}` coins".format(price) if price else '',
|
||||
"\nThis role will expire at <t:{:.0f}>.".format(
|
||||
expiry.timestamp()
|
||||
) if expiry else ''
|
||||
),
|
||||
title="Reaction Role Added"
|
||||
)
|
||||
|
||||
async def process_raw_reaction_remove(self, payload):
|
||||
"""
|
||||
Process a general reaction remove payload.
|
||||
"""
|
||||
if self.settings.removable.value:
|
||||
event_log = GuildSettings(self.guild.id).event_log
|
||||
async with self._locks[payload.user_id]:
|
||||
reaction = next((reaction for reaction in self.reactions if reaction.emoji == payload.emoji), None)
|
||||
if reaction:
|
||||
# User removed a live reaction. Process!
|
||||
member = self.guild.get_member(payload.user_id)
|
||||
role = reaction.role
|
||||
if member and not member.bot and role and (role in member.roles):
|
||||
# Check whether they have the required role, if set
|
||||
required_role = self.settings.required_role.value
|
||||
if required_role and required_role not in member.roles:
|
||||
# Ignore the reaction removal
|
||||
return
|
||||
|
||||
try:
|
||||
await member.remove_roles(
|
||||
role,
|
||||
atomic=True,
|
||||
reason="Removing reaction role."
|
||||
)
|
||||
except discord.Forbidden:
|
||||
event_log.log(
|
||||
"Insufficient permissions to remove "
|
||||
"the [reaction role]({}) {} from {}".format(
|
||||
self.message_link,
|
||||
role.mention,
|
||||
member.mention,
|
||||
),
|
||||
title="Failed to remove reaction role",
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
except discord.HTTPException:
|
||||
event_log.log(
|
||||
"Something went wrong while removing the [reaction role]({}) "
|
||||
"{} from {}.".format(
|
||||
self.message_link,
|
||||
role.mention,
|
||||
member.mention
|
||||
),
|
||||
title="Failed to remove reaction role",
|
||||
colour=discord.Colour.red()
|
||||
)
|
||||
client.log(
|
||||
"Unexpected HTTPException encountered while removing '{}' (rid:{}) from "
|
||||
"user '{}' (uid:{}) in guild '{}' (gid:{}).\n{}".format(
|
||||
role.name,
|
||||
role.id,
|
||||
member,
|
||||
member.id,
|
||||
member.guild.name,
|
||||
member.guild.id,
|
||||
traceback.format_exc()
|
||||
),
|
||||
context="REACTION_ROLE_RM",
|
||||
level=logging.WARNING
|
||||
)
|
||||
else:
|
||||
# Economy hook, handle refund if required
|
||||
price = reaction.settings.price.value
|
||||
refund = self.settings.refunds.value
|
||||
if price and refund:
|
||||
# Give the user the refund
|
||||
lion = Lion.fetch(self.guild.id, member.id)
|
||||
lion.addCoins(price)
|
||||
|
||||
# Notify the user
|
||||
embed = discord.Embed(
|
||||
title="Role sold",
|
||||
description=(
|
||||
"You sold the role **{}** for `{}` coins.".format(
|
||||
role.name,
|
||||
price
|
||||
)
|
||||
),
|
||||
colour=discord.Colour.green()
|
||||
).set_footer(
|
||||
icon_url=self.guild.icon_url,
|
||||
text=self.guild.name
|
||||
).add_field(
|
||||
name="Jump Back",
|
||||
value="[Click Here]({})".format(self.message_link)
|
||||
)
|
||||
try:
|
||||
await member.send(embed=embed)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
# Log role removal if required
|
||||
if self.settings.log.value:
|
||||
event_log.log(
|
||||
"Removed [reaction role]({}) {} "
|
||||
"from {}.".format(
|
||||
self.message_link,
|
||||
role.mention,
|
||||
member.mention
|
||||
),
|
||||
title="Reaction Role Removed"
|
||||
)
|
||||
|
||||
# Cancel any existing expiry
|
||||
cancel_expiry(self.guild.id, member.id, role.id)
|
||||
|
||||
|
||||
# TODO: Make all the embeds a bit nicer, and maybe make a consistent interface for them
|
||||
# TODO: Handle RawMessageDelete event
|
||||
# TODO: Handle permission errors when fetching message in config
|
||||
|
||||
@client.add_after_event('raw_reaction_add')
|
||||
async def reaction_role_add(client, payload):
|
||||
reaction_message = ReactionRoleMessage.fetch(payload.message_id)
|
||||
if payload.guild_id and payload.user_id != client.user.id and reaction_message and reaction_message.enabled:
|
||||
try:
|
||||
await reaction_message.process_raw_reaction_add(payload)
|
||||
except Exception:
|
||||
# Unknown exception, catch and log it.
|
||||
client.log(
|
||||
"Unhandled exception while handling reaction message payload: {}\n{}".format(
|
||||
payload,
|
||||
traceback.format_exc()
|
||||
),
|
||||
context="REACTION_ROLE_ADD",
|
||||
level=logging.ERROR
|
||||
)
|
||||
|
||||
|
||||
@client.add_after_event('raw_reaction_remove')
|
||||
async def reaction_role_remove(client, payload):
|
||||
reaction_message = ReactionRoleMessage.fetch(payload.message_id)
|
||||
if payload.guild_id and reaction_message and reaction_message.enabled:
|
||||
try:
|
||||
await reaction_message.process_raw_reaction_remove(payload)
|
||||
except Exception:
|
||||
# Unknown exception, catch and log it.
|
||||
client.log(
|
||||
"Unhandled exception while handling reaction message payload: {}\n{}".format(
|
||||
payload,
|
||||
traceback.format_exc()
|
||||
),
|
||||
context="REACTION_ROLE_RM",
|
||||
level=logging.ERROR
|
||||
)
|
||||
|
||||
|
||||
@module.init_task
|
||||
def load_reaction_roles(client):
|
||||
"""
|
||||
Load the ReactionRoleMessages.
|
||||
"""
|
||||
rows = reaction_role_messages.fetch_rows_where(guildid=THIS_SHARD)
|
||||
ReactionRoleMessage._messages = {row.messageid: ReactionRoleMessage(row.messageid) for row in rows}
|
||||
65
src/modules/pending-rewrite/guild_admin/statreset.py
Normal file
65
src/modules/pending-rewrite/guild_admin/statreset.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from io import StringIO
|
||||
|
||||
import discord
|
||||
from wards import guild_admin
|
||||
from data import tables
|
||||
from core import Lion
|
||||
|
||||
from .module import module
|
||||
|
||||
|
||||
@module.cmd("studyreset",
|
||||
desc="Perform a reset of the server's study statistics.",
|
||||
group="Guild Admin")
|
||||
@guild_admin()
|
||||
async def cmd_statreset(ctx):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}studyreset
|
||||
Description:
|
||||
Perform a complete reset of the server's study statistics.
|
||||
That is, deletes the tracked time of all members and removes their study badges.
|
||||
|
||||
This may be used to set "seasons" of study.
|
||||
|
||||
Before the reset, I will send a csv file with the current member statistics.
|
||||
|
||||
**This is not reversible.**
|
||||
"""
|
||||
if not await ctx.ask("Are you sure you want to reset the study time and badges for all members? "
|
||||
"**THIS IS NOT REVERSIBLE!**"):
|
||||
return
|
||||
# Build the data csv
|
||||
rows = tables.lions.select_where(
|
||||
select_columns=('userid', 'tracked_time', 'coins', 'workout_count', 'b.roleid AS badge_roleid'),
|
||||
_extra=(
|
||||
"LEFT JOIN study_badges b ON last_study_badgeid = b.badgeid "
|
||||
"WHERE members.guildid={}"
|
||||
).format(ctx.guild.id)
|
||||
)
|
||||
header = "userid, tracked_time, coins, workouts, rank_roleid\n"
|
||||
csv_rows = [
|
||||
','.join(str(data) for data in row)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
with StringIO() as stats_file:
|
||||
stats_file.write(header)
|
||||
stats_file.write('\n'.join(csv_rows))
|
||||
stats_file.seek(0)
|
||||
|
||||
out_file = discord.File(stats_file, filename="guild_{}_member_statistics.csv".format(ctx.guild.id))
|
||||
await ctx.reply(file=out_file)
|
||||
|
||||
# Reset the statistics
|
||||
tables.lions.update_where(
|
||||
{'tracked_time': 0},
|
||||
guildid=ctx.guild.id
|
||||
)
|
||||
|
||||
Lion.sync()
|
||||
|
||||
await ctx.embed_reply(
|
||||
"The member study times have been reset!\n"
|
||||
"(It may take a while for the studybadges to update.)"
|
||||
)
|
||||
7
src/modules/pending-rewrite/meta/__init__.py
Normal file
7
src/modules/pending-rewrite/meta/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# flake8: noqa
|
||||
from .module import module
|
||||
|
||||
from . import help
|
||||
from . import links
|
||||
from . import nerd
|
||||
from . import join_message
|
||||
237
src/modules/pending-rewrite/meta/help.py
Normal file
237
src/modules/pending-rewrite/meta/help.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import discord
|
||||
from cmdClient.checks import is_owner
|
||||
|
||||
from utils.lib import prop_tabulate
|
||||
from utils import interactive, ctx_addons # noqa
|
||||
from wards import is_guild_admin
|
||||
|
||||
from .module import module
|
||||
from .lib import guide_link
|
||||
|
||||
|
||||
new_emoji = " 🆕"
|
||||
new_commands = {'botconfig', 'sponsors'}
|
||||
|
||||
# Set the command groups to appear in the help
|
||||
group_hints = {
|
||||
'Pomodoro': "*Stay in sync with your friends using our timers!*",
|
||||
'Productivity': "*Use these to help you stay focused and productive!*",
|
||||
'Statistics': "*StudyLion leaderboards and study statistics.*",
|
||||
'Economy': "*Buy, sell, and trade with your hard-earned coins!*",
|
||||
'Personal Settings': "*Tell me about yourself!*",
|
||||
'Guild Admin': "*Dangerous administration commands!*",
|
||||
'Guild Configuration': "*Control how I behave in your server.*",
|
||||
'Meta': "*Information about me!*",
|
||||
'Support Us': "*Support the team and keep the project alive by using LionGems!*"
|
||||
}
|
||||
|
||||
standard_group_order = (
|
||||
('Pomodoro', 'Productivity', 'Support Us', 'Statistics', 'Economy', 'Personal Settings', 'Meta'),
|
||||
)
|
||||
|
||||
mod_group_order = (
|
||||
('Moderation', 'Meta'),
|
||||
('Pomodoro', 'Productivity', 'Support Us', 'Statistics', 'Economy', 'Personal Settings')
|
||||
)
|
||||
|
||||
admin_group_order = (
|
||||
('Guild Admin', 'Guild Configuration', 'Moderation', 'Meta'),
|
||||
('Pomodoro', 'Productivity', 'Support Us', 'Statistics', 'Economy', 'Personal Settings')
|
||||
)
|
||||
|
||||
bot_admin_group_order = (
|
||||
('Bot Admin', 'Guild Admin', 'Guild Configuration', 'Moderation', 'Meta'),
|
||||
('Pomodoro', 'Productivity', 'Support Us', 'Statistics', 'Economy', 'Personal Settings')
|
||||
)
|
||||
|
||||
# Help embed format
|
||||
# TODO: Add config fields for this
|
||||
title = "StudyLion Command List"
|
||||
header = """
|
||||
[StudyLion](https://bot.studylions.com/) is a fully featured study assistant \
|
||||
that tracks your study time and offers productivity tools \
|
||||
such as to-do lists, task reminders, private study rooms, group accountability sessions, and much much more.\n
|
||||
Use `{{ctx.best_prefix}}help <command>` (e.g. `{{ctx.best_prefix}}help send`) to learn how to use each command, \
|
||||
or [click here]({guide_link}) for a comprehensive tutorial.
|
||||
""".format(guide_link=guide_link)
|
||||
|
||||
|
||||
@module.cmd("help",
|
||||
group="Meta",
|
||||
desc="StudyLion command list.",
|
||||
aliases=('man', 'ls', 'list'))
|
||||
async def cmd_help(ctx):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}help [cmdname]
|
||||
Description:
|
||||
When used with no arguments, displays a list of commands with brief descriptions.
|
||||
Otherwise, shows documentation for the provided command.
|
||||
Examples:
|
||||
{prefix}help
|
||||
{prefix}help top
|
||||
{prefix}help timezone
|
||||
"""
|
||||
if ctx.arg_str:
|
||||
# Attempt to fetch the command
|
||||
command = ctx.client.cmd_names.get(ctx.arg_str.strip(), None)
|
||||
if command is None:
|
||||
return await ctx.error_reply(
|
||||
("Command `{}` not found!\n"
|
||||
"Write `{}help` to see a list of commands.").format(ctx.args, ctx.best_prefix)
|
||||
)
|
||||
|
||||
smart_help = getattr(command, 'smart_help', None)
|
||||
if smart_help is not None:
|
||||
return await smart_help(ctx)
|
||||
|
||||
help_fields = command.long_help.copy()
|
||||
help_map = {field_name: i for i, (field_name, _) in enumerate(help_fields)}
|
||||
|
||||
if not help_map:
|
||||
return await ctx.reply("No documentation has been written for this command yet!")
|
||||
|
||||
field_pages = [[]]
|
||||
page_fields = field_pages[0]
|
||||
for name, pos in help_map.items():
|
||||
if name.endswith("``"):
|
||||
# Handle codeline help fields
|
||||
page_fields.append((
|
||||
name.strip("`"),
|
||||
"`{}`".format('`\n`'.join(help_fields[pos][1].splitlines()))
|
||||
))
|
||||
elif name.endswith(":"):
|
||||
# Handle property/value help fields
|
||||
lines = help_fields[pos][1].splitlines()
|
||||
|
||||
names = []
|
||||
values = []
|
||||
for line in lines:
|
||||
split = line.split(":", 1)
|
||||
names.append(split[0] if len(split) > 1 else "")
|
||||
values.append(split[-1])
|
||||
|
||||
page_fields.append((
|
||||
name.strip(':'),
|
||||
prop_tabulate(names, values)
|
||||
))
|
||||
elif name == "Related":
|
||||
# Handle the related field
|
||||
names = [cmd_name.strip() for cmd_name in help_fields[pos][1].split(',')]
|
||||
names.sort(key=len)
|
||||
values = [
|
||||
(getattr(ctx.client.cmd_names.get(cmd_name, None), 'desc', '') or '').format(ctx=ctx)
|
||||
for cmd_name in names
|
||||
]
|
||||
page_fields.append((
|
||||
name,
|
||||
prop_tabulate(names, values)
|
||||
))
|
||||
elif name == "PAGEBREAK":
|
||||
page_fields = []
|
||||
field_pages.append(page_fields)
|
||||
else:
|
||||
page_fields.append((name, help_fields[pos][1]))
|
||||
|
||||
# Build the aliases
|
||||
aliases = getattr(command, 'aliases', [])
|
||||
alias_str = "(Aliases `{}`.)".format("`, `".join(aliases)) if aliases else ""
|
||||
|
||||
# Build the embeds
|
||||
pages = []
|
||||
for i, page_fields in enumerate(field_pages):
|
||||
embed = discord.Embed(
|
||||
title="`{}` command documentation. {}".format(
|
||||
command.name,
|
||||
alias_str
|
||||
),
|
||||
colour=discord.Colour(0x9b59b6)
|
||||
)
|
||||
for fieldname, fieldvalue in page_fields:
|
||||
embed.add_field(
|
||||
name=fieldname,
|
||||
value=fieldvalue.format(ctx=ctx, prefix=ctx.best_prefix),
|
||||
inline=False
|
||||
)
|
||||
|
||||
embed.set_footer(
|
||||
text="{}\n[optional] and <required> denote optional and required arguments, respectively.".format(
|
||||
"Page {} of {}".format(i + 1, len(field_pages)) if len(field_pages) > 1 else '',
|
||||
)
|
||||
)
|
||||
pages.append(embed)
|
||||
|
||||
# Post the embed
|
||||
await ctx.pager(pages)
|
||||
else:
|
||||
# Build the command groups
|
||||
cmd_groups = {}
|
||||
for command in ctx.client.cmds:
|
||||
# Get the command group
|
||||
group = getattr(command, 'group', "Misc")
|
||||
cmd_group = cmd_groups.get(group, [])
|
||||
if not cmd_group:
|
||||
cmd_groups[group] = cmd_group
|
||||
|
||||
# Add the command name and description to the group
|
||||
cmd_group.append(
|
||||
(command.name, (getattr(command, 'desc', '') + (new_emoji if command.name in new_commands else '')))
|
||||
)
|
||||
|
||||
# Add any required aliases
|
||||
for alias, desc in getattr(command, 'help_aliases', {}).items():
|
||||
cmd_group.append((alias, desc))
|
||||
|
||||
# Turn the command groups into strings
|
||||
stringy_cmd_groups = {}
|
||||
for group_name, cmd_group in cmd_groups.items():
|
||||
cmd_group.sort(key=lambda tup: len(tup[0]))
|
||||
if ctx.alias == 'ls':
|
||||
stringy_cmd_groups[group_name] = ', '.join(
|
||||
f"`{name}`" for name, _ in cmd_group
|
||||
)
|
||||
else:
|
||||
stringy_cmd_groups[group_name] = prop_tabulate(*zip(*cmd_group))
|
||||
|
||||
# Now put everything into a bunch of embeds
|
||||
if await is_owner.run(ctx):
|
||||
group_order = bot_admin_group_order
|
||||
elif ctx.guild:
|
||||
if is_guild_admin(ctx.author):
|
||||
group_order = admin_group_order
|
||||
elif ctx.guild_settings.mod_role.value in ctx.author.roles:
|
||||
group_order = mod_group_order
|
||||
else:
|
||||
group_order = standard_group_order
|
||||
else:
|
||||
group_order = admin_group_order
|
||||
|
||||
help_embeds = []
|
||||
for page_groups in group_order:
|
||||
embed = discord.Embed(
|
||||
description=header.format(ctx=ctx),
|
||||
colour=discord.Colour(0x9b59b6),
|
||||
title=title
|
||||
)
|
||||
for group in page_groups:
|
||||
group_hint = group_hints.get(group, '').format(ctx=ctx)
|
||||
group_str = stringy_cmd_groups.get(group, None)
|
||||
if group_str:
|
||||
embed.add_field(
|
||||
name=group,
|
||||
value="{}\n{}".format(group_hint, group_str).format(ctx=ctx),
|
||||
inline=False
|
||||
)
|
||||
help_embeds.append(embed)
|
||||
|
||||
# Add the page numbers
|
||||
for i, embed in enumerate(help_embeds):
|
||||
embed.set_footer(text="Page {}/{}".format(i+1, len(help_embeds)))
|
||||
|
||||
# Send the embeds
|
||||
if help_embeds:
|
||||
await ctx.pager(help_embeds)
|
||||
else:
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(description=header, colour=discord.Colour(0x9b59b6))
|
||||
)
|
||||
50
src/modules/pending-rewrite/meta/join_message.py
Normal file
50
src/modules/pending-rewrite/meta/join_message.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import discord
|
||||
|
||||
from cmdClient import cmdClient
|
||||
|
||||
from meta import client, conf
|
||||
from .lib import guide_link, animation_link
|
||||
|
||||
|
||||
message = """
|
||||
Thank you for inviting me to your community.
|
||||
Get started by typing `{prefix}help` to see my commands, and `{prefix}config info` \
|
||||
to read about my configuration options!
|
||||
|
||||
To learn how to configure me and use all of my features, \
|
||||
make sure to [click here]({guide_link}) to read our full setup guide.
|
||||
|
||||
Remember, if you need any help configuring me, \
|
||||
want to suggest a feature, report a bug and stay updated, \
|
||||
make sure to join our main support and study server by [clicking here]({support_link}).
|
||||
|
||||
Best of luck with your studies!
|
||||
|
||||
""".format(
|
||||
guide_link=guide_link,
|
||||
support_link=conf.bot.get('support_link'),
|
||||
prefix=client.prefix
|
||||
)
|
||||
|
||||
|
||||
@client.add_after_event('guild_join', priority=0)
|
||||
async def post_join_message(client: cmdClient, guild: discord.Guild):
|
||||
try:
|
||||
await guild.me.edit(nick="Leo")
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
if (channel := guild.system_channel) and channel.permissions_for(guild.me).embed_links:
|
||||
embed = discord.Embed(
|
||||
description=message
|
||||
)
|
||||
embed.set_author(
|
||||
name="Hello everyone! My name is Leo, the StudyLion!",
|
||||
icon_url="https://cdn.discordapp.com/emojis/933610591459872868.webp"
|
||||
)
|
||||
embed.set_image(url=animation_link)
|
||||
try:
|
||||
await channel.send(embed=embed)
|
||||
except discord.HTTPException:
|
||||
# Something went wrong sending the hi message
|
||||
# Not much we can do about this
|
||||
pass
|
||||
5
src/modules/pending-rewrite/meta/lib.py
Normal file
5
src/modules/pending-rewrite/meta/lib.py
Normal file
@@ -0,0 +1,5 @@
|
||||
guide_link = "https://discord.studylions.com/tutorial"
|
||||
|
||||
animation_link = (
|
||||
"https://media.discordapp.net/attachments/879412267731542047/926837189814419486/ezgif.com-resize.gif"
|
||||
)
|
||||
57
src/modules/pending-rewrite/meta/links.py
Normal file
57
src/modules/pending-rewrite/meta/links.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import discord
|
||||
|
||||
from meta import conf
|
||||
|
||||
from LionContext import LionContext as Context
|
||||
|
||||
from .module import module
|
||||
from .lib import guide_link
|
||||
|
||||
|
||||
@module.cmd(
|
||||
"support",
|
||||
group="Meta",
|
||||
desc=f"Have a question? Join my [support server]({conf.bot.get('support_link')})"
|
||||
)
|
||||
async def cmd_support(ctx: Context):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}support
|
||||
Description:
|
||||
Replies with an invite link to my support server.
|
||||
"""
|
||||
await ctx.reply(
|
||||
f"Click here to join my support server: {conf.bot.get('support_link')}"
|
||||
)
|
||||
|
||||
|
||||
@module.cmd(
|
||||
"invite",
|
||||
group="Meta",
|
||||
desc=f"[Invite me]({conf.bot.get('invite_link')}) to your server so I can help your members stay productive!"
|
||||
)
|
||||
async def cmd_invite(ctx: Context):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}invite
|
||||
Description:
|
||||
Replies with my invite link so you can add me to your server.
|
||||
"""
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
description=f"Click here to add me to your server: {conf.bot.get('invite_link')}"
|
||||
)
|
||||
embed.add_field(
|
||||
name="Setup tips",
|
||||
value=(
|
||||
"Remember to check out `{prefix}help` for the full command list, "
|
||||
"and `{prefix}config info` for the configuration options.\n"
|
||||
"[Click here]({guide}) for our comprehensive setup tutorial, and if you still have questions you can "
|
||||
"join our support server [here]({support}) to talk to our friendly support team!"
|
||||
).format(
|
||||
prefix=ctx.best_prefix,
|
||||
support=conf.bot.get('support_link'),
|
||||
guide=guide_link
|
||||
)
|
||||
)
|
||||
await ctx.reply(embed=embed)
|
||||
3
src/modules/pending-rewrite/meta/module.py
Normal file
3
src/modules/pending-rewrite/meta/module.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from LionModule import LionModule
|
||||
|
||||
module = LionModule("Meta")
|
||||
144
src/modules/pending-rewrite/meta/nerd.py
Normal file
144
src/modules/pending-rewrite/meta/nerd.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import datetime
|
||||
import asyncio
|
||||
import discord
|
||||
import psutil
|
||||
import sys
|
||||
import gc
|
||||
|
||||
from data import NOTNULL
|
||||
from data.queries import select_where
|
||||
from utils.lib import prop_tabulate, utc_now
|
||||
|
||||
from LionContext import LionContext as Context
|
||||
|
||||
from .module import module
|
||||
|
||||
|
||||
process = psutil.Process()
|
||||
process.cpu_percent()
|
||||
|
||||
|
||||
@module.cmd(
|
||||
"nerd",
|
||||
group="Meta",
|
||||
desc="Information and statistics about me!"
|
||||
)
|
||||
async def cmd_nerd(ctx: Context):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}nerd
|
||||
Description:
|
||||
View nerdy information and statistics about me!
|
||||
"""
|
||||
# Create embed
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
title="Nerd Panel",
|
||||
description=(
|
||||
"Hi! I'm [StudyLion]({studylion}), a study management bot owned by "
|
||||
"[Ari Horesh]({ari}) and developed by [Conatum#5317]({cona}), with [contributors]({github})."
|
||||
).format(
|
||||
studylion="http://studylions.com/",
|
||||
ari="https://arihoresh.com/",
|
||||
cona="https://github.com/Intery",
|
||||
github="https://github.com/StudyLions/StudyLion"
|
||||
)
|
||||
)
|
||||
|
||||
# ----- Study stats -----
|
||||
# Current studying statistics
|
||||
current_students, current_channels, current_guilds= (
|
||||
ctx.client.data.current_sessions.select_one_where(
|
||||
select_columns=(
|
||||
"COUNT(*) AS studying_count",
|
||||
"COUNT(DISTINCT(channelid)) AS channel_count",
|
||||
"COUNT(DISTINCT(guildid)) AS guild_count"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Past studying statistics
|
||||
past_sessions, past_students, past_duration, past_guilds = ctx.client.data.session_history.select_one_where(
|
||||
select_columns=(
|
||||
"COUNT(*) AS session_count",
|
||||
"COUNT(DISTINCT(userid)) AS user_count",
|
||||
"SUM(duration) / 3600 AS total_hours",
|
||||
"COUNT(DISTINCT(guildid)) AS guild_count"
|
||||
)
|
||||
)
|
||||
|
||||
# Tasklist statistics
|
||||
tasks = ctx.client.data.tasklist.select_one_where(
|
||||
select_columns=(
|
||||
'COUNT(*)'
|
||||
)
|
||||
)[0]
|
||||
|
||||
tasks_completed = ctx.client.data.tasklist.select_one_where(
|
||||
completed_at=NOTNULL,
|
||||
select_columns=(
|
||||
'COUNT(*)'
|
||||
)
|
||||
)[0]
|
||||
|
||||
# Timers
|
||||
timer_count, timer_guilds = ctx.client.data.timers.select_one_where(
|
||||
select_columns=("COUNT(*)", "COUNT(DISTINCT(guildid))")
|
||||
)
|
||||
|
||||
study_fields = {
|
||||
"Currently": f"`{current_students}` people working in `{current_channels}` rooms of `{current_guilds}` guilds",
|
||||
"Recorded": f"`{past_duration}` hours from `{past_students}` people across `{past_sessions}` sessions",
|
||||
"Tasks": f"`{tasks_completed}` out of `{tasks}` tasks completed",
|
||||
"Timers": f"`{timer_count}` timers running in `{timer_guilds}` communities"
|
||||
}
|
||||
study_table = prop_tabulate(*zip(*study_fields.items()))
|
||||
|
||||
# ----- Shard statistics -----
|
||||
shard_number = ctx.client.shard_id
|
||||
shard_count = ctx.client.shard_count
|
||||
guilds = len(ctx.client.guilds)
|
||||
member_count = sum(guild.member_count for guild in ctx.client.guilds)
|
||||
commands = len(ctx.client.cmds)
|
||||
aliases = len(ctx.client.cmd_names)
|
||||
dpy_version = discord.__version__
|
||||
py_version = sys.version.split()[0]
|
||||
data_version, data_time, _ = select_where(
|
||||
"VersionHistory",
|
||||
_extra="ORDER BY time DESC LIMIT 1"
|
||||
)[0]
|
||||
data_timestamp = int(data_time.replace(tzinfo=datetime.timezone.utc).timestamp())
|
||||
|
||||
shard_fields = {
|
||||
"Shard": f"`{shard_number}` of `{shard_count}`",
|
||||
"Guilds": f"`{guilds}` servers with `{member_count}` members (on this shard)",
|
||||
"Commands": f"`{commands}` commands with `{aliases}` keywords",
|
||||
"Version": f"`v{data_version}`, last updated <t:{data_timestamp}:F>",
|
||||
"Py version": f"`{py_version}` running discord.py `{dpy_version}`"
|
||||
}
|
||||
shard_table = prop_tabulate(*zip(*shard_fields.items()))
|
||||
|
||||
|
||||
# ----- Execution statistics -----
|
||||
running_commands = len(ctx.client.active_contexts)
|
||||
tasks = len(asyncio.all_tasks())
|
||||
objects = len(gc.get_objects())
|
||||
cpu_percent = process.cpu_percent()
|
||||
mem_percent = int(process.memory_percent())
|
||||
uptime = int(utc_now().timestamp() - process.create_time())
|
||||
|
||||
execution_fields = {
|
||||
"Running": f"`{running_commands}` commands",
|
||||
"Waiting for": f"`{tasks}` tasks to complete",
|
||||
"Objects": f"`{objects}` loaded in memory",
|
||||
"Usage": f"`{cpu_percent}%` CPU, `{mem_percent}%` MEM",
|
||||
"Uptime": f"`{uptime // (24 * 3600)}` days, `{uptime // 3600 % 24:02}:{uptime // 60 % 60:02}:{uptime % 60:02}`"
|
||||
}
|
||||
execution_table = prop_tabulate(*zip(*execution_fields.items()))
|
||||
|
||||
# ----- Combine and output -----
|
||||
embed.add_field(name="Study Stats", value=study_table, inline=False)
|
||||
embed.add_field(name=f"Shard Info", value=shard_table, inline=False)
|
||||
embed.add_field(name=f"Process Stats", value=execution_table, inline=False)
|
||||
|
||||
await ctx.reply(embed=embed)
|
||||
9
src/modules/pending-rewrite/moderation/__init__.py
Normal file
9
src/modules/pending-rewrite/moderation/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .module import module
|
||||
|
||||
from . import data
|
||||
from . import admin
|
||||
|
||||
from . import tickets
|
||||
from . import video
|
||||
|
||||
from . import commands
|
||||
109
src/modules/pending-rewrite/moderation/admin.py
Normal file
109
src/modules/pending-rewrite/moderation/admin.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import discord
|
||||
|
||||
from settings import GuildSettings, GuildSetting
|
||||
from wards import guild_admin
|
||||
|
||||
import settings
|
||||
|
||||
from .data import studyban_durations
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class mod_log(settings.Channel, GuildSetting):
|
||||
category = "Moderation"
|
||||
|
||||
attr_name = 'mod_log'
|
||||
_data_column = 'mod_log_channel'
|
||||
|
||||
display_name = "mod_log"
|
||||
desc = "Moderation event logging channel."
|
||||
|
||||
long_desc = (
|
||||
"Channel to post moderation tickets.\n"
|
||||
"These are produced when a manual or automatic moderation action is performed on a member. "
|
||||
"This channel acts as a more context rich moderation history source than the audit log."
|
||||
)
|
||||
|
||||
_chan_type = discord.ChannelType.text
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "Moderation tickets will be posted to {}.".format(self.formatted)
|
||||
else:
|
||||
return "The moderation log has been unset."
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class studyban_role(settings.Role, GuildSetting):
|
||||
category = "Moderation"
|
||||
|
||||
attr_name = 'studyban_role'
|
||||
_data_column = 'studyban_role'
|
||||
|
||||
display_name = "studyban_role"
|
||||
desc = "The role given to members to prevent them from using server study features."
|
||||
|
||||
long_desc = (
|
||||
"This role is to be given to members to prevent them from using the server's study features.\n"
|
||||
"Typically, this role should act as a 'partial mute', and prevent the user from joining study voice channels, "
|
||||
"or participating in study text channels.\n"
|
||||
"It will be given automatically after study related offences, "
|
||||
"such as not enabling video in the video-only channels."
|
||||
)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "The study ban role is now {}.".format(self.formatted)
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class studyban_durations(settings.SettingList, settings.ListData, settings.Setting):
|
||||
category = "Moderation"
|
||||
|
||||
attr_name = 'studyban_durations'
|
||||
|
||||
_table_interface = studyban_durations
|
||||
_id_column = 'guildid'
|
||||
_data_column = 'duration'
|
||||
_order_column = "rowid"
|
||||
|
||||
_default = [
|
||||
5 * 60,
|
||||
60 * 60,
|
||||
6 * 60 * 60,
|
||||
24 * 60 * 60,
|
||||
168 * 60 * 60,
|
||||
720 * 60 * 60
|
||||
]
|
||||
|
||||
_setting = settings.Duration
|
||||
|
||||
write_ward = guild_admin
|
||||
display_name = "studyban_durations"
|
||||
desc = "Sequence of durations for automatic study bans."
|
||||
|
||||
long_desc = (
|
||||
"This sequence describes how long a member will be automatically study-banned for "
|
||||
"after committing a study-related offence (such as not enabling their video in video only channels).\n"
|
||||
"If the sequence is `1d, 7d, 30d`, for example, the member will be study-banned "
|
||||
"for `1d` on their first offence, `7d` on their second offence, and `30d` on their third. "
|
||||
"On their fourth offence, they will not be unbanned.\n"
|
||||
"This does not count pardoned offences."
|
||||
)
|
||||
accepts = (
|
||||
"Comma separated list of durations in days/hours/minutes/seconds, for example `12h, 1d, 7d, 30d`."
|
||||
)
|
||||
|
||||
# Flat cache, no need to expire objects
|
||||
_cache = {}
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "The automatic study ban durations are now {}.".format(self.formatted)
|
||||
else:
|
||||
return "Automatic study bans will never be reverted."
|
||||
|
||||
|
||||
448
src/modules/pending-rewrite/moderation/commands.py
Normal file
448
src/modules/pending-rewrite/moderation/commands.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""
|
||||
Shared commands for the moderation module.
|
||||
"""
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import discord
|
||||
|
||||
from cmdClient.lib import ResponseTimedOut
|
||||
from wards import guild_moderator
|
||||
|
||||
from .module import module
|
||||
from .tickets import Ticket, TicketType, TicketState
|
||||
|
||||
|
||||
type_accepts = {
|
||||
'note': TicketType.NOTE,
|
||||
'notes': TicketType.NOTE,
|
||||
'studyban': TicketType.STUDY_BAN,
|
||||
'studybans': TicketType.STUDY_BAN,
|
||||
'warn': TicketType.WARNING,
|
||||
'warns': TicketType.WARNING,
|
||||
'warning': TicketType.WARNING,
|
||||
'warnings': TicketType.WARNING,
|
||||
}
|
||||
|
||||
type_formatted = {
|
||||
TicketType.NOTE: 'NOTE',
|
||||
TicketType.STUDY_BAN: 'STUDYBAN',
|
||||
TicketType.WARNING: 'WARNING',
|
||||
}
|
||||
|
||||
type_summary_formatted = {
|
||||
TicketType.NOTE: 'note',
|
||||
TicketType.STUDY_BAN: 'studyban',
|
||||
TicketType.WARNING: 'warning',
|
||||
}
|
||||
|
||||
state_formatted = {
|
||||
TicketState.OPEN: 'ACTIVE',
|
||||
TicketState.EXPIRING: 'TEMP',
|
||||
TicketState.EXPIRED: 'EXPIRED',
|
||||
TicketState.PARDONED: 'PARDONED'
|
||||
}
|
||||
|
||||
state_summary_formatted = {
|
||||
TicketState.OPEN: 'Active',
|
||||
TicketState.EXPIRING: 'Temporary',
|
||||
TicketState.EXPIRED: 'Expired',
|
||||
TicketState.REVERTED: 'Manually Reverted',
|
||||
TicketState.PARDONED: 'Pardoned'
|
||||
}
|
||||
|
||||
|
||||
@module.cmd(
|
||||
"tickets",
|
||||
group="Moderation",
|
||||
desc="View and filter the server moderation tickets.",
|
||||
flags=('active', 'type=')
|
||||
)
|
||||
@guild_moderator()
|
||||
async def cmd_tickets(ctx, flags):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}tickets [@user] [--type <type>] [--active]
|
||||
Description:
|
||||
Display and optionally filter the moderation event history in this guild.
|
||||
Flags::
|
||||
type: Filter by ticket type. See **Ticket Types** below.
|
||||
active: Only show in-effect tickets (i.e. hide expired and pardoned ones).
|
||||
Ticket Types::
|
||||
note: Moderation notes.
|
||||
warn: Moderation warnings, both manual and automatic.
|
||||
studyban: Bans from using study features from abusing the study system.
|
||||
blacklist: Complete blacklisting from using my commands.
|
||||
Ticket States::
|
||||
Active: Active tickets that will not automatically expire.
|
||||
Temporary: Active tickets that will automatically expire after a set duration.
|
||||
Expired: Tickets that have automatically expired.
|
||||
Reverted: Tickets with actions that have been reverted.
|
||||
Pardoned: Tickets that have been pardoned and no longer apply to the user.
|
||||
Examples:
|
||||
{prefix}tickets {ctx.guild.owner.mention} --type warn --active
|
||||
"""
|
||||
# Parse filter fields
|
||||
# First the user
|
||||
if ctx.args:
|
||||
userstr = ctx.args.strip('<@!&> ')
|
||||
if not userstr.isdigit():
|
||||
return await ctx.error_reply(
|
||||
"**Usage:** `{prefix}tickets [@user] [--type <type>] [--active]`.\n"
|
||||
"Please provide the `user` as a mention or id!".format(prefix=ctx.best_prefix)
|
||||
)
|
||||
filter_userid = int(userstr)
|
||||
else:
|
||||
filter_userid = None
|
||||
|
||||
if flags['type']:
|
||||
typestr = flags['type'].lower()
|
||||
if typestr not in type_accepts:
|
||||
return await ctx.error_reply(
|
||||
"Please see `{prefix}help tickets` for the valid ticket types!".format(prefix=ctx.best_prefix)
|
||||
)
|
||||
filter_type = type_accepts[typestr]
|
||||
else:
|
||||
filter_type = None
|
||||
|
||||
filter_active = flags['active']
|
||||
|
||||
# Build the filter arguments
|
||||
filters = {'guildid': ctx.guild.id}
|
||||
if filter_userid:
|
||||
filters['targetid'] = filter_userid
|
||||
if filter_type:
|
||||
filters['ticket_type'] = filter_type
|
||||
if filter_active:
|
||||
filters['ticket_state'] = [TicketState.OPEN, TicketState.EXPIRING]
|
||||
|
||||
# Fetch the tickets with these filters
|
||||
tickets = Ticket.fetch_tickets(**filters)
|
||||
|
||||
if not tickets:
|
||||
if filters:
|
||||
return await ctx.embed_reply("There are no tickets with these criteria!")
|
||||
else:
|
||||
return await ctx.embed_reply("There are no moderation tickets in this server!")
|
||||
|
||||
tickets = sorted(tickets, key=lambda ticket: ticket.data.guild_ticketid, reverse=True)
|
||||
ticket_map = {ticket.data.guild_ticketid: ticket for ticket in tickets}
|
||||
|
||||
# Build the format string based on the filters
|
||||
components = []
|
||||
# Ticket id with link to message in mod log
|
||||
components.append("[#{ticket.data.guild_ticketid}]({ticket.link})")
|
||||
# Ticket creation date
|
||||
components.append("<t:{timestamp:.0f}:d>")
|
||||
# Ticket type, with current state
|
||||
if filter_type is None:
|
||||
if not filter_active:
|
||||
components.append("`{ticket_type}{ticket_state}`")
|
||||
else:
|
||||
components.append("`{ticket_type}`")
|
||||
elif not filter_active:
|
||||
components.append("`{ticket_real_state}`")
|
||||
if not filter_userid:
|
||||
# Ticket user
|
||||
components.append("<@{ticket.data.targetid}>")
|
||||
if filter_userid or (filter_active and filter_type):
|
||||
# Truncated ticket content
|
||||
components.append("{content}")
|
||||
|
||||
format_str = ' | '.join(components)
|
||||
|
||||
# Break tickets into blocks
|
||||
blocks = [tickets[i:i+10] for i in range(0, len(tickets), 10)]
|
||||
|
||||
# Build pages of tickets
|
||||
ticket_pages = []
|
||||
for block in blocks:
|
||||
ticket_page = []
|
||||
|
||||
type_len = max(len(type_formatted[ticket.type]) for ticket in block)
|
||||
state_len = max(len(state_formatted[ticket.state]) for ticket in block)
|
||||
for ticket in block:
|
||||
# First truncate content if required
|
||||
content = ticket.data.content
|
||||
if len(content) > 40:
|
||||
content = content[:37] + '...'
|
||||
|
||||
# Build ticket line
|
||||
line = format_str.format(
|
||||
ticket=ticket,
|
||||
timestamp=ticket.data.created_at.timestamp(),
|
||||
ticket_type=type_formatted[ticket.type],
|
||||
type_len=type_len,
|
||||
ticket_state=" [{}]".format(state_formatted[ticket.state]) if ticket.state != TicketState.OPEN else '',
|
||||
ticket_real_state=state_formatted[ticket.state],
|
||||
state_len=state_len,
|
||||
content=content
|
||||
)
|
||||
if ticket.state == TicketState.PARDONED:
|
||||
line = "~~{}~~".format(line)
|
||||
|
||||
# Add to current page
|
||||
ticket_page.append(line)
|
||||
# Combine lines and add page to pages
|
||||
ticket_pages.append('\n'.join(ticket_page))
|
||||
|
||||
# Build active ticket type summary
|
||||
freq = defaultdict(int)
|
||||
for ticket in tickets:
|
||||
if ticket.state != TicketState.PARDONED:
|
||||
freq[ticket.type] += 1
|
||||
summary_pairs = [
|
||||
(num, type_summary_formatted[ttype] + ('s' if num > 1 else ''))
|
||||
for ttype, num in freq.items()
|
||||
]
|
||||
summary_pairs.sort(key=lambda pair: pair[0])
|
||||
# num_len = max(len(str(num)) for num in freq.values())
|
||||
# type_summary = '\n'.join(
|
||||
# "**`{:<{}}`** {}".format(pair[0], num_len, pair[1])
|
||||
# for pair in summary_pairs
|
||||
# )
|
||||
|
||||
# # Build status summary
|
||||
# freq = defaultdict(int)
|
||||
# for ticket in tickets:
|
||||
# freq[ticket.state] += 1
|
||||
# num_len = max(len(str(num)) for num in freq.values())
|
||||
# status_summary = '\n'.join(
|
||||
# "**`{:<{}}`** {}".format(freq[state], num_len, state_str)
|
||||
# for state, state_str in state_summary_formatted.items()
|
||||
# if state in freq
|
||||
# )
|
||||
|
||||
summary_strings = [
|
||||
"**`{}`** {}".format(*pair) for pair in summary_pairs
|
||||
]
|
||||
if len(summary_strings) > 2:
|
||||
summary = ', '.join(summary_strings[:-1]) + ', and ' + summary_strings[-1]
|
||||
elif len(summary_strings) == 2:
|
||||
summary = ' and '.join(summary_strings)
|
||||
else:
|
||||
summary = ''.join(summary_strings)
|
||||
if summary:
|
||||
summary += '.'
|
||||
|
||||
# Build embed info
|
||||
title = "{}{}{}".format(
|
||||
"Active " if filter_active else '',
|
||||
"{} tickets ".format(type_formatted[filter_type]) if filter_type else "Tickets ",
|
||||
(" for {}".format(ctx.guild.get_member(filter_userid) or filter_userid)
|
||||
if filter_userid else " in {}".format(ctx.guild.name))
|
||||
)
|
||||
footer = "Click a ticket id to jump to it, or type the number to show the full ticket."
|
||||
page_count = len(blocks)
|
||||
if page_count > 1:
|
||||
footer += "\nPage {{page_num}}/{}".format(page_count)
|
||||
|
||||
# Create embeds
|
||||
embeds = [
|
||||
discord.Embed(
|
||||
title=title,
|
||||
description="{}\n{}".format(summary, page),
|
||||
colour=discord.Colour.orange(),
|
||||
).set_footer(text=footer.format(page_num=i+1))
|
||||
for i, page in enumerate(ticket_pages)
|
||||
]
|
||||
|
||||
# Run output with cancellation and listener
|
||||
out_msg = await ctx.pager(embeds, add_cancel=True)
|
||||
old_task = _displays.pop((ctx.ch.id, ctx.author.id), None)
|
||||
if old_task:
|
||||
old_task.cancel()
|
||||
_displays[(ctx.ch.id, ctx.author.id)] = display_task = asyncio.create_task(_ticket_display(ctx, ticket_map))
|
||||
ctx.tasks.append(display_task)
|
||||
await ctx.cancellable(out_msg, add_reaction=False)
|
||||
|
||||
|
||||
_displays = {} # (channelid, userid) -> Task
|
||||
async def _ticket_display(ctx, ticket_map):
|
||||
"""
|
||||
Display tickets when the ticket number is entered.
|
||||
"""
|
||||
current_ticket_msg = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait for a number
|
||||
try:
|
||||
result = await ctx.client.wait_for(
|
||||
"message",
|
||||
check=lambda msg: (msg.author == ctx.author
|
||||
and msg.channel == ctx.ch
|
||||
and msg.content.isdigit()
|
||||
and int(msg.content) in ticket_map),
|
||||
timeout=60
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return
|
||||
|
||||
# Delete the response
|
||||
try:
|
||||
await result.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
# Display the ticket
|
||||
embed = ticket_map[int(result.content)].msg_args['embed']
|
||||
if current_ticket_msg:
|
||||
try:
|
||||
await current_ticket_msg.edit(embed=embed)
|
||||
except discord.HTTPException:
|
||||
current_ticket_msg = None
|
||||
|
||||
if not current_ticket_msg:
|
||||
try:
|
||||
current_ticket_msg = await ctx.reply(embed=embed)
|
||||
except discord.HTTPException:
|
||||
return
|
||||
asyncio.create_task(ctx.offer_delete(current_ticket_msg))
|
||||
except asyncio.CancelledError:
|
||||
if current_ticket_msg:
|
||||
try:
|
||||
await current_ticket_msg.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@module.cmd(
|
||||
"pardon",
|
||||
group="Moderation",
|
||||
desc="Pardon a ticket, or clear a member's moderation history.",
|
||||
flags=('type=',)
|
||||
)
|
||||
@guild_moderator()
|
||||
async def cmd_pardon(ctx, flags):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}pardon ticketid, ticketid, ticketid
|
||||
{prefix}pardon @user [--type <type>]
|
||||
Description:
|
||||
Marks the given tickets as no longer applicable.
|
||||
These tickets will not be considered when calculating automod actions such as automatic study bans.
|
||||
|
||||
This may be used to mark warns or other tickets as no longer in-effect.
|
||||
If the ticket is active when it is pardoned, it will be reverted, and any expiry cancelled.
|
||||
|
||||
Use the `{prefix}tickets` command to view the relevant tickets.
|
||||
Flags::
|
||||
type: Filter by ticket type. See **Ticket Types** in `{prefix}help tickets`.
|
||||
Examples:
|
||||
{prefix}pardon 21
|
||||
{prefix}pardon {ctx.guild.owner.mention} --type warn
|
||||
"""
|
||||
usage = "**Usage**: `{prefix}pardon ticketid` or `{prefix}pardon @user`.".format(prefix=ctx.best_prefix)
|
||||
if not ctx.args:
|
||||
return await ctx.error_reply(
|
||||
usage
|
||||
)
|
||||
|
||||
# Parse provided tickets or filters
|
||||
targetid = None
|
||||
ticketids = []
|
||||
args = {'guildid': ctx.guild.id}
|
||||
if ',' in ctx.args:
|
||||
# Assume provided numbers are ticketids.
|
||||
items = [item.strip() for item in ctx.args.split(',')]
|
||||
if not all(item.isdigit() for item in items):
|
||||
return await ctx.error_reply(usage)
|
||||
ticketids = [int(item) for item in items]
|
||||
args['guild_ticketid'] = ticketids
|
||||
else:
|
||||
# Guess whether the provided numbers were ticketids or not
|
||||
idstr = ctx.args.strip('<@!&> ')
|
||||
if not idstr.isdigit():
|
||||
return await ctx.error_reply(usage)
|
||||
|
||||
maybe_id = int(idstr)
|
||||
if maybe_id > 4194304: # Testing whether it is greater than the minimum snowflake id
|
||||
# Assume userid
|
||||
targetid = maybe_id
|
||||
args['targetid'] = maybe_id
|
||||
|
||||
# Add the type filter if provided
|
||||
if flags['type']:
|
||||
typestr = flags['type'].lower()
|
||||
if typestr not in type_accepts:
|
||||
return await ctx.error_reply(
|
||||
"Please see `{prefix}help tickets` for the valid ticket types!".format(prefix=ctx.best_prefix)
|
||||
)
|
||||
args['ticket_type'] = type_accepts[typestr]
|
||||
else:
|
||||
# Assume guild ticketid
|
||||
ticketids = [maybe_id]
|
||||
args['guild_ticketid'] = maybe_id
|
||||
|
||||
# Fetch the matching tickets
|
||||
tickets = Ticket.fetch_tickets(**args)
|
||||
|
||||
# Check whether we have the right selection of tickets
|
||||
if targetid and not tickets:
|
||||
return await ctx.error_reply(
|
||||
"<@{}> has no matching tickets to pardon!"
|
||||
)
|
||||
if ticketids and len(ticketids) != len(tickets):
|
||||
# Not all of the ticketids were valid
|
||||
difference = list(set(ticketids).difference(ticket.ticketid for ticket in tickets))
|
||||
if len(difference) == 1:
|
||||
return await ctx.error_reply(
|
||||
"Couldn't find ticket `{}`!".format(difference[0])
|
||||
)
|
||||
else:
|
||||
return await ctx.error_reply(
|
||||
"Couldn't find any of the following tickets:\n`{}`".format(
|
||||
'`, `'.join(difference)
|
||||
)
|
||||
)
|
||||
|
||||
# Check whether there are any tickets left to pardon
|
||||
to_pardon = [ticket for ticket in tickets if ticket.state != TicketState.PARDONED]
|
||||
if not to_pardon:
|
||||
if ticketids and len(tickets) == 1:
|
||||
ticket = tickets[0]
|
||||
return await ctx.error_reply(
|
||||
"[Ticket #{}]({}) is already pardoned!".format(ticket.data.guild_ticketid, ticket.link)
|
||||
)
|
||||
else:
|
||||
return await ctx.error_reply(
|
||||
"All of these tickets are already pardoned!"
|
||||
)
|
||||
|
||||
# We now know what tickets we want to pardon
|
||||
# Request the pardon reason
|
||||
try:
|
||||
reason = await ctx.input("Please provide a reason for the pardon.")
|
||||
except ResponseTimedOut:
|
||||
raise ResponseTimedOut("Prompt timed out, no tickets were pardoned.")
|
||||
|
||||
# Pardon the tickets
|
||||
for ticket in to_pardon:
|
||||
await ticket.pardon(ctx.author, reason)
|
||||
|
||||
# Finally, ack the pardon
|
||||
if targetid:
|
||||
await ctx.embed_reply(
|
||||
"The active {}s for <@{}> have been cleared.".format(
|
||||
type_summary_formatted[args['ticket_type']] if flags['type'] else 'ticket',
|
||||
targetid
|
||||
)
|
||||
)
|
||||
elif len(to_pardon) == 1:
|
||||
ticket = to_pardon[0]
|
||||
await ctx.embed_reply(
|
||||
"[Ticket #{}]({}) was pardoned.".format(
|
||||
ticket.data.guild_ticketid,
|
||||
ticket.link
|
||||
)
|
||||
)
|
||||
else:
|
||||
await ctx.embed_reply(
|
||||
"The following tickets were pardoned.\n{}".format(
|
||||
", ".join(
|
||||
"[#{}]({})".format(ticket.data.guild_ticketid, ticket.link)
|
||||
for ticket in to_pardon
|
||||
)
|
||||
)
|
||||
)
|
||||
19
src/modules/pending-rewrite/moderation/data.py
Normal file
19
src/modules/pending-rewrite/moderation/data.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from data import Table, RowTable
|
||||
|
||||
|
||||
studyban_durations = Table('studyban_durations')
|
||||
|
||||
ticket_info = RowTable(
|
||||
'ticket_info',
|
||||
('ticketid', 'guild_ticketid',
|
||||
'guildid', 'targetid', 'ticket_type', 'ticket_state', 'moderator_id', 'auto',
|
||||
'log_msg_id', 'created_at',
|
||||
'content', 'context', 'addendum', 'duration',
|
||||
'file_name', 'file_data',
|
||||
'expiry',
|
||||
'pardoned_by', 'pardoned_at', 'pardoned_reason'),
|
||||
'ticketid',
|
||||
cache_size=20000
|
||||
)
|
||||
|
||||
tickets = Table('tickets')
|
||||
4
src/modules/pending-rewrite/moderation/module.py
Normal file
4
src/modules/pending-rewrite/moderation/module.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from cmdClient import Module
|
||||
|
||||
|
||||
module = Module("Moderation")
|
||||
486
src/modules/pending-rewrite/moderation/tickets/Ticket.py
Normal file
486
src/modules/pending-rewrite/moderation/tickets/Ticket.py
Normal file
@@ -0,0 +1,486 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
import datetime
|
||||
|
||||
import discord
|
||||
|
||||
from meta import client
|
||||
from data.conditions import THIS_SHARD
|
||||
from settings import GuildSettings
|
||||
from utils.lib import FieldEnum, strfdelta, utc_now
|
||||
|
||||
from .. import data
|
||||
from ..module import module
|
||||
|
||||
|
||||
class TicketType(FieldEnum):
|
||||
"""
|
||||
The possible ticket types.
|
||||
"""
|
||||
NOTE = 'NOTE', 'Note'
|
||||
WARNING = 'WARNING', 'Warning'
|
||||
STUDY_BAN = 'STUDY_BAN', 'Study Ban'
|
||||
MESAGE_CENSOR = 'MESSAGE_CENSOR', 'Message Censor'
|
||||
INVITE_CENSOR = 'INVITE_CENSOR', 'Invite Censor'
|
||||
|
||||
|
||||
class TicketState(FieldEnum):
|
||||
"""
|
||||
The possible ticket states.
|
||||
"""
|
||||
OPEN = 'OPEN', "Active"
|
||||
EXPIRING = 'EXPIRING', "Active"
|
||||
EXPIRED = 'EXPIRED', "Expired"
|
||||
PARDONED = 'PARDONED', "Pardoned"
|
||||
REVERTED = 'REVERTED', "Reverted"
|
||||
|
||||
|
||||
class Ticket:
|
||||
"""
|
||||
Abstract base class representing a Ticketed moderation action.
|
||||
"""
|
||||
# Type of event the class represents
|
||||
_ticket_type = None # type: TicketType
|
||||
|
||||
_ticket_types = {} # Map: TicketType -> Ticket subclass
|
||||
|
||||
_expiry_tasks = {} # Map: ticketid -> expiry Task
|
||||
|
||||
def __init__(self, ticketid, *args, **kwargs):
|
||||
self.ticketid = ticketid
|
||||
|
||||
@classmethod
|
||||
async def create(cls, *args, **kwargs):
|
||||
"""
|
||||
Method used to create a new ticket of the current type.
|
||||
Should add a row to the ticket table, post the ticket, and return the Ticket.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
"""
|
||||
Ticket row.
|
||||
This will usually be a row of `ticket_info`.
|
||||
"""
|
||||
return data.ticket_info.fetch(self.ticketid)
|
||||
|
||||
@property
|
||||
def guild(self):
|
||||
return client.get_guild(self.data.guildid)
|
||||
|
||||
@property
|
||||
def target(self):
|
||||
guild = self.guild
|
||||
return guild.get_member(self.data.targetid) if guild else None
|
||||
|
||||
@property
|
||||
def msg_args(self):
|
||||
"""
|
||||
Ticket message posted in the moderation log.
|
||||
"""
|
||||
args = {}
|
||||
|
||||
# Build embed
|
||||
info = self.data
|
||||
member = self.target
|
||||
name = str(member) if member else str(info.targetid)
|
||||
|
||||
if info.auto:
|
||||
title_fmt = "Ticket #{} | {} | {}[Auto] | {}"
|
||||
else:
|
||||
title_fmt = "Ticket #{} | {} | {} | {}"
|
||||
title = title_fmt.format(
|
||||
info.guild_ticketid,
|
||||
TicketState(info.ticket_state).desc,
|
||||
TicketType(info.ticket_type).desc,
|
||||
name
|
||||
)
|
||||
|
||||
embed = discord.Embed(
|
||||
title=title,
|
||||
description=info.content,
|
||||
timestamp=info.created_at
|
||||
)
|
||||
embed.add_field(
|
||||
name="Target",
|
||||
value="<@{}>".format(info.targetid)
|
||||
)
|
||||
|
||||
if not info.auto:
|
||||
embed.add_field(
|
||||
name="Moderator",
|
||||
value="<@{}>".format(info.moderator_id)
|
||||
)
|
||||
|
||||
# if info.duration:
|
||||
# value = "`{}` {}".format(
|
||||
# strfdelta(datetime.timedelta(seconds=info.duration)),
|
||||
# "(Expiry <t:{:.0f}>)".format(info.expiry.timestamp()) if info.expiry else ""
|
||||
# )
|
||||
# embed.add_field(
|
||||
# name="Duration",
|
||||
# value=value
|
||||
# )
|
||||
if info.expiry:
|
||||
if info.ticket_state == TicketState.EXPIRING:
|
||||
embed.add_field(
|
||||
name="Expires at",
|
||||
value="<t:{:.0f}>\n(Duration: `{}`)".format(
|
||||
info.expiry.timestamp(),
|
||||
strfdelta(datetime.timedelta(seconds=info.duration))
|
||||
)
|
||||
)
|
||||
elif info.ticket_state == TicketState.EXPIRED:
|
||||
embed.add_field(
|
||||
name="Expired",
|
||||
value="<t:{:.0f}>".format(
|
||||
info.expiry.timestamp(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="Expiry",
|
||||
value="<t:{:.0f}>".format(
|
||||
info.expiry.timestamp()
|
||||
)
|
||||
)
|
||||
|
||||
if info.context:
|
||||
embed.add_field(
|
||||
name="Context",
|
||||
value=info.context,
|
||||
inline=False
|
||||
)
|
||||
|
||||
if info.addendum:
|
||||
embed.add_field(
|
||||
name="Notes",
|
||||
value=info.addendum,
|
||||
inline=False
|
||||
)
|
||||
|
||||
if self.state == TicketState.PARDONED:
|
||||
embed.add_field(
|
||||
name="Pardoned",
|
||||
value=(
|
||||
"Pardoned by <@{}> at <t:{:.0f}>.\n{}"
|
||||
).format(
|
||||
info.pardoned_by,
|
||||
info.pardoned_at.timestamp(),
|
||||
info.pardoned_reason or ""
|
||||
),
|
||||
inline=False
|
||||
)
|
||||
|
||||
embed.set_footer(text="ID: {}".format(info.targetid))
|
||||
|
||||
args['embed'] = embed
|
||||
|
||||
# Add file
|
||||
if info.file_name:
|
||||
args['file'] = discord.File(info.file_data, info.file_name)
|
||||
|
||||
return args
|
||||
|
||||
@property
|
||||
def link(self):
|
||||
"""
|
||||
The link to the ticket in the moderation log.
|
||||
"""
|
||||
info = self.data
|
||||
modlog = GuildSettings(info.guildid).mod_log.data
|
||||
|
||||
return 'https://discord.com/channels/{}/{}/{}'.format(
|
||||
info.guildid,
|
||||
modlog,
|
||||
info.log_msg_id
|
||||
)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return TicketState(self.data.ticket_state)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return TicketType(self.data.ticket_type)
|
||||
|
||||
async def update(self, **kwargs):
|
||||
"""
|
||||
Update ticket fields.
|
||||
"""
|
||||
fields = (
|
||||
'targetid', 'moderator_id', 'auto', 'log_msg_id',
|
||||
'content', 'expiry', 'ticket_state',
|
||||
'context', 'addendum', 'duration', 'file_name', 'file_data',
|
||||
'pardoned_by', 'pardoned_at', 'pardoned_reason',
|
||||
)
|
||||
params = {field: kwargs[field] for field in fields if field in kwargs}
|
||||
if params:
|
||||
data.ticket_info.update_where(params, ticketid=self.ticketid)
|
||||
|
||||
await self.update_expiry()
|
||||
await self.post()
|
||||
|
||||
async def post(self):
|
||||
"""
|
||||
Post or update the ticket in the moderation log.
|
||||
Also updates the saved message id.
|
||||
"""
|
||||
info = self.data
|
||||
modlog = GuildSettings(info.guildid).mod_log.value
|
||||
if not modlog:
|
||||
return
|
||||
|
||||
resend = True
|
||||
try:
|
||||
if info.log_msg_id:
|
||||
# Try to fetch the message
|
||||
message = await modlog.fetch_message(info.log_msg_id)
|
||||
if message:
|
||||
if message.author.id == client.user.id:
|
||||
# TODO: Handle file edit
|
||||
await message.edit(embed=self.msg_args['embed'])
|
||||
resend = False
|
||||
else:
|
||||
try:
|
||||
await message.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
if resend:
|
||||
message = await modlog.send(**self.msg_args)
|
||||
self.data.log_msg_id = message.id
|
||||
except discord.HTTPException:
|
||||
client.log(
|
||||
"Cannot post ticket (tid: {}) due to discord exception or issue.".format(self.ticketid)
|
||||
)
|
||||
except Exception:
|
||||
# This should never happen in normal operation
|
||||
client.log(
|
||||
"Error while posting ticket (tid:{})! "
|
||||
"Exception traceback follows.\n{}".format(
|
||||
self.ticketid,
|
||||
traceback.format_exc()
|
||||
),
|
||||
context="TICKETS",
|
||||
level=logging.ERROR
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_expiring(cls):
|
||||
"""
|
||||
Load and schedule all expiring tickets.
|
||||
"""
|
||||
# TODO: Consider changing this to a flat timestamp system, to avoid storing lots of coroutines.
|
||||
# TODO: Consider only scheduling the expiries in the next day, and updating this once per day.
|
||||
# TODO: Only fetch tickets from guilds we are in.
|
||||
|
||||
# Cancel existing expiry tasks
|
||||
for task in cls._expiry_tasks.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
# Get all expiring tickets
|
||||
expiring_rows = data.tickets.select_where(
|
||||
ticket_state=TicketState.EXPIRING,
|
||||
guildid=THIS_SHARD
|
||||
)
|
||||
|
||||
# Create new expiry tasks
|
||||
now = utc_now()
|
||||
cls._expiry_tasks = {
|
||||
row['ticketid']: asyncio.create_task(
|
||||
cls._schedule_expiry_for(
|
||||
row['ticketid'],
|
||||
(row['expiry'] - now).total_seconds()
|
||||
)
|
||||
) for row in expiring_rows
|
||||
}
|
||||
|
||||
# Log
|
||||
client.log(
|
||||
"Loaded {} expiring tickets.".format(len(cls._expiry_tasks)),
|
||||
context="TICKET_LOADER",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _schedule_expiry_for(cls, ticketid, delay):
|
||||
"""
|
||||
Schedule expiry for a given ticketid
|
||||
"""
|
||||
try:
|
||||
await asyncio.sleep(delay)
|
||||
ticket = Ticket.fetch(ticketid)
|
||||
if ticket:
|
||||
await asyncio.shield(ticket._expire())
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
def update_expiry(self):
|
||||
# Cancel any existing expiry task
|
||||
task = self._expiry_tasks.pop(self.ticketid, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
# Schedule a new expiry task, if applicable
|
||||
if self.data.ticket_state == TicketState.EXPIRING:
|
||||
self._expiry_tasks[self.ticketid] = asyncio.create_task(
|
||||
self._schedule_expiry_for(
|
||||
self.ticketid,
|
||||
(self.data.expiry - utc_now()).total_seconds()
|
||||
)
|
||||
)
|
||||
|
||||
async def cancel_expiry(self):
|
||||
"""
|
||||
Cancel ticket expiry.
|
||||
|
||||
In particular, may be used if another ticket overrides `self`.
|
||||
Sets the ticket state to `OPEN`, so that it no longer expires.
|
||||
"""
|
||||
if self.state == TicketState.EXPIRING:
|
||||
# Update the ticket state
|
||||
self.data.ticket_state = TicketState.OPEN
|
||||
|
||||
# Remove from expiry tsks
|
||||
self.update_expiry()
|
||||
|
||||
# Repost
|
||||
await self.post()
|
||||
|
||||
async def _revert(self, reason=None):
|
||||
"""
|
||||
Method used to revert the ticket action, e.g. unban or remove mute role.
|
||||
Generally called by `pardon` and `_expire`.
|
||||
|
||||
May be overriden by the Ticket type, if they implement any revert logic.
|
||||
Is a no-op by default.
|
||||
"""
|
||||
return
|
||||
|
||||
async def _expire(self):
|
||||
"""
|
||||
Method to automatically expire a ticket.
|
||||
|
||||
May be overriden by the Ticket type for more complex expiry logic.
|
||||
Must set `data.ticket_state` to `EXPIRED` if applicable.
|
||||
"""
|
||||
if self.state == TicketState.EXPIRING:
|
||||
client.log(
|
||||
"Automatically expiring ticket (tid:{}).".format(self.ticketid),
|
||||
context="TICKETS"
|
||||
)
|
||||
try:
|
||||
await self._revert(reason="Automatic Expiry")
|
||||
except Exception:
|
||||
# This should never happen in normal operation
|
||||
client.log(
|
||||
"Error while expiring ticket (tid:{})! "
|
||||
"Exception traceback follows.\n{}".format(
|
||||
self.ticketid,
|
||||
traceback.format_exc()
|
||||
),
|
||||
context="TICKETS",
|
||||
level=logging.ERROR
|
||||
)
|
||||
|
||||
# Update state
|
||||
self.data.ticket_state = TicketState.EXPIRED
|
||||
|
||||
# Update log message
|
||||
await self.post()
|
||||
|
||||
# Post a note to the modlog
|
||||
modlog = GuildSettings(self.data.guildid).mod_log.value
|
||||
if modlog:
|
||||
try:
|
||||
await modlog.send(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
description="[Ticket #{}]({}) expired!".format(self.data.guild_ticketid, self.link)
|
||||
)
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
async def pardon(self, moderator, reason, timestamp=None):
|
||||
"""
|
||||
Pardon process for the ticket.
|
||||
|
||||
May be overidden by the Ticket type for more complex pardon logic.
|
||||
Must set `data.ticket_state` to `PARDONED` if applicable.
|
||||
"""
|
||||
if self.state != TicketState.PARDONED:
|
||||
if self.state in (TicketState.OPEN, TicketState.EXPIRING):
|
||||
try:
|
||||
await self._revert(reason="Pardoned by {}".format(moderator.id))
|
||||
except Exception:
|
||||
# This should never happen in normal operation
|
||||
client.log(
|
||||
"Error while pardoning ticket (tid:{})! "
|
||||
"Exception traceback follows.\n{}".format(
|
||||
self.ticketid,
|
||||
traceback.format_exc()
|
||||
),
|
||||
context="TICKETS",
|
||||
level=logging.ERROR
|
||||
)
|
||||
|
||||
# Update state
|
||||
with self.data.batch_update():
|
||||
self.data.ticket_state = TicketState.PARDONED
|
||||
self.data.pardoned_at = utc_now()
|
||||
self.data.pardoned_by = moderator.id
|
||||
self.data.pardoned_reason = reason
|
||||
|
||||
# Update (i.e. remove) expiry
|
||||
self.update_expiry()
|
||||
|
||||
# Update log message
|
||||
await self.post()
|
||||
|
||||
@classmethod
|
||||
def fetch_tickets(cls, *ticketids, **kwargs):
|
||||
"""
|
||||
Fetch tickets matching the given criteria (passed transparently to `select_where`).
|
||||
Positional arguments are treated as `ticketids`, which are not supported in keyword arguments.
|
||||
"""
|
||||
if ticketids:
|
||||
kwargs['ticketid'] = ticketids
|
||||
|
||||
# Set the ticket type to the class type if not specified
|
||||
if cls._ticket_type and 'ticket_type' not in kwargs:
|
||||
kwargs['ticket_type'] = cls._ticket_type
|
||||
|
||||
# This is actually mainly for caching, since we don't pass the data to the initialiser
|
||||
rows = data.ticket_info.fetch_rows_where(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return [
|
||||
cls._ticket_types[TicketType(row.ticket_type)](row.ticketid)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def fetch(cls, ticketid):
|
||||
"""
|
||||
Return the Ticket with the given id, if found, or `None` otherwise.
|
||||
"""
|
||||
tickets = cls.fetch_tickets(ticketid)
|
||||
return tickets[0] if tickets else None
|
||||
|
||||
@classmethod
|
||||
def register_ticket_type(cls, ticket_cls):
|
||||
"""
|
||||
Decorator to register a new Ticket subclass as a ticket type.
|
||||
"""
|
||||
cls._ticket_types[ticket_cls._ticket_type] = ticket_cls
|
||||
return ticket_cls
|
||||
|
||||
|
||||
@module.launch_task
|
||||
async def load_expiring_tickets(client):
|
||||
Ticket.load_expiring()
|
||||
@@ -0,0 +1,4 @@
|
||||
from .Ticket import Ticket, TicketType, TicketState
|
||||
from .studybans import StudyBanTicket
|
||||
from .notes import NoteTicket
|
||||
from .warns import WarnTicket
|
||||
112
src/modules/pending-rewrite/moderation/tickets/notes.py
Normal file
112
src/modules/pending-rewrite/moderation/tickets/notes.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Note ticket implementation.
|
||||
|
||||
Guild moderators can add a note about a user, visible in their moderation history.
|
||||
Notes appear in the moderation log and the user's ticket history, like any other ticket.
|
||||
|
||||
This module implements the Note TicketType and the `note` moderation command.
|
||||
"""
|
||||
from cmdClient.lib import ResponseTimedOut
|
||||
|
||||
from wards import guild_moderator
|
||||
|
||||
from ..module import module
|
||||
from ..data import tickets
|
||||
|
||||
from .Ticket import Ticket, TicketType, TicketState
|
||||
|
||||
|
||||
@Ticket.register_ticket_type
|
||||
class NoteTicket(Ticket):
|
||||
_ticket_type = TicketType.NOTE
|
||||
|
||||
@classmethod
|
||||
async def create(cls, guildid, targetid, moderatorid, content, **kwargs):
|
||||
"""
|
||||
Create a new Note on a target.
|
||||
|
||||
`kwargs` are passed transparently to the table insert method.
|
||||
"""
|
||||
ticket_row = tickets.insert(
|
||||
guildid=guildid,
|
||||
targetid=targetid,
|
||||
ticket_type=cls._ticket_type,
|
||||
ticket_state=TicketState.OPEN,
|
||||
moderator_id=moderatorid,
|
||||
auto=False,
|
||||
content=content,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Create the note ticket
|
||||
ticket = cls(ticket_row['ticketid'])
|
||||
|
||||
# Post the ticket and return
|
||||
await ticket.post()
|
||||
return ticket
|
||||
|
||||
|
||||
@module.cmd(
|
||||
"note",
|
||||
group="Moderation",
|
||||
desc="Add a Note to a member's record."
|
||||
)
|
||||
@guild_moderator()
|
||||
async def cmd_note(ctx):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}note @target
|
||||
{prefix}note @target <content>
|
||||
Description:
|
||||
Add a note to the target's moderation record.
|
||||
The note will appear in the moderation log and in the `tickets` command.
|
||||
|
||||
The `target` must be specificed by mention or user id.
|
||||
If the `content` is not given, it will be prompted for.
|
||||
Example:
|
||||
{prefix}note {ctx.author.mention} Seen reading the `note` documentation.
|
||||
"""
|
||||
if not ctx.args:
|
||||
return await ctx.error_reply(
|
||||
"**Usage:** `{}note @target <content>`.".format(ctx.best_prefix)
|
||||
)
|
||||
|
||||
# Extract the target. We don't require them to be in the server
|
||||
splits = ctx.args.split(maxsplit=1)
|
||||
target_str = splits[0].strip('<@!&> ')
|
||||
if not target_str.isdigit():
|
||||
return await ctx.error_reply(
|
||||
"**Usage:** `{}note @target <content>`.\n"
|
||||
"`target` must be provided by mention or userid.".format(ctx.best_prefix)
|
||||
)
|
||||
targetid = int(target_str)
|
||||
|
||||
# Extract or prompt for the content
|
||||
if len(splits) != 2:
|
||||
try:
|
||||
content = await ctx.input("What note would you like to add?", timeout=300)
|
||||
except ResponseTimedOut:
|
||||
raise ResponseTimedOut("Prompt timed out, no note was created.")
|
||||
else:
|
||||
content = splits[1].strip()
|
||||
|
||||
# Create the note ticket
|
||||
ticket = await NoteTicket.create(
|
||||
ctx.guild.id,
|
||||
targetid,
|
||||
ctx.author.id,
|
||||
content
|
||||
)
|
||||
|
||||
if ticket.data.log_msg_id:
|
||||
await ctx.embed_reply(
|
||||
"Note on <@{}> created as [Ticket #{}]({}).".format(
|
||||
targetid,
|
||||
ticket.data.guild_ticketid,
|
||||
ticket.link
|
||||
)
|
||||
)
|
||||
else:
|
||||
await ctx.embed_reply(
|
||||
"Note on <@{}> created as Ticket #{}.".format(targetid, ticket.data.guild_ticketid)
|
||||
)
|
||||
126
src/modules/pending-rewrite/moderation/tickets/studybans.py
Normal file
126
src/modules/pending-rewrite/moderation/tickets/studybans.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import datetime
|
||||
import discord
|
||||
|
||||
from meta import client
|
||||
from utils.lib import utc_now
|
||||
from settings import GuildSettings
|
||||
from data import NOT
|
||||
|
||||
from .. import data
|
||||
from .Ticket import Ticket, TicketType, TicketState
|
||||
|
||||
|
||||
@Ticket.register_ticket_type
|
||||
class StudyBanTicket(Ticket):
|
||||
_ticket_type = TicketType.STUDY_BAN
|
||||
|
||||
@classmethod
|
||||
async def create(cls, guildid, targetid, moderatorid, reason, expiry=None, **kwargs):
|
||||
"""
|
||||
Create a new study ban ticket.
|
||||
"""
|
||||
# First create the ticket itself
|
||||
ticket_row = data.tickets.insert(
|
||||
guildid=guildid,
|
||||
targetid=targetid,
|
||||
ticket_type=cls._ticket_type,
|
||||
ticket_state=TicketState.EXPIRING if expiry else TicketState.OPEN,
|
||||
moderator_id=moderatorid,
|
||||
auto=(moderatorid == client.user.id),
|
||||
content=reason,
|
||||
expiry=expiry,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Create the Ticket
|
||||
ticket = cls(ticket_row['ticketid'])
|
||||
|
||||
# Schedule ticket expiry, if applicable
|
||||
if expiry:
|
||||
ticket.update_expiry()
|
||||
|
||||
# Cancel any existing studyban expiry for this member
|
||||
tickets = cls.fetch_tickets(
|
||||
guildid=guildid,
|
||||
ticketid=NOT(ticket_row['ticketid']),
|
||||
targetid=targetid,
|
||||
ticket_state=TicketState.EXPIRING
|
||||
)
|
||||
for ticket in tickets:
|
||||
await ticket.cancel_expiry()
|
||||
|
||||
# Post the ticket
|
||||
await ticket.post()
|
||||
|
||||
# Return the ticket
|
||||
return ticket
|
||||
|
||||
async def _revert(self, reason=None):
|
||||
"""
|
||||
Revert the studyban by removing the role.
|
||||
"""
|
||||
guild_settings = GuildSettings(self.data.guildid)
|
||||
role = guild_settings.studyban_role.value
|
||||
target = self.target
|
||||
|
||||
if target and role:
|
||||
try:
|
||||
await target.remove_roles(
|
||||
role,
|
||||
reason="Reverting StudyBan: {}".format(reason)
|
||||
)
|
||||
except discord.HTTPException:
|
||||
# TODO: Error log?
|
||||
...
|
||||
|
||||
@classmethod
|
||||
async def autoban(cls, guild, target, reason, **kwargs):
|
||||
"""
|
||||
Convenience method to automatically studyban a member, for the configured duration.
|
||||
If the role is set, this will create and return a `StudyBanTicket` regardless of whether the
|
||||
studyban was successful.
|
||||
If the role is not set, or the ticket cannot be created, this will return `None`.
|
||||
"""
|
||||
# Get the studyban role, fail if there isn't one set, or the role doesn't exist
|
||||
guild_settings = GuildSettings(guild.id)
|
||||
role = guild_settings.studyban_role.value
|
||||
if not role:
|
||||
return None
|
||||
|
||||
# Attempt to add the role, record failure
|
||||
try:
|
||||
await target.add_roles(role, reason="Applying StudyBan: {}".format(reason[:400]))
|
||||
except discord.HTTPException:
|
||||
role_failed = True
|
||||
else:
|
||||
role_failed = False
|
||||
|
||||
# Calculate the applicable automatic duration and expiry
|
||||
# First count the existing non-pardoned studybans for this target
|
||||
studyban_count = data.tickets.select_one_where(
|
||||
guildid=guild.id,
|
||||
targetid=target.id,
|
||||
ticket_type=cls._ticket_type,
|
||||
ticket_state=NOT(TicketState.PARDONED),
|
||||
select_columns=('COUNT(*)',)
|
||||
)[0]
|
||||
studyban_count = int(studyban_count)
|
||||
|
||||
# Then read the guild setting to find the applicable duration
|
||||
studyban_durations = guild_settings.studyban_durations.value
|
||||
if studyban_count < len(studyban_durations):
|
||||
duration = studyban_durations[studyban_count]
|
||||
expiry = utc_now() + datetime.timedelta(seconds=duration)
|
||||
else:
|
||||
duration = None
|
||||
expiry = None
|
||||
|
||||
# Create the ticket and return
|
||||
if role_failed:
|
||||
kwargs['addendum'] = '\n'.join((
|
||||
kwargs.get('addendum', ''),
|
||||
"Could not add the studyban role! Please add the role manually and check my permissions."
|
||||
))
|
||||
return await cls.create(
|
||||
guild.id, target.id, client.user.id, reason, duration=duration, expiry=expiry, **kwargs
|
||||
)
|
||||
153
src/modules/pending-rewrite/moderation/tickets/warns.py
Normal file
153
src/modules/pending-rewrite/moderation/tickets/warns.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Warn ticket implementation.
|
||||
|
||||
Guild moderators can officially warn a user via command.
|
||||
This DMs the users with the warning.
|
||||
"""
|
||||
import datetime
|
||||
import discord
|
||||
from cmdClient.lib import ResponseTimedOut
|
||||
|
||||
from wards import guild_moderator
|
||||
|
||||
from ..module import module
|
||||
from ..data import tickets
|
||||
|
||||
from .Ticket import Ticket, TicketType, TicketState
|
||||
|
||||
|
||||
@Ticket.register_ticket_type
|
||||
class WarnTicket(Ticket):
|
||||
_ticket_type = TicketType.WARNING
|
||||
|
||||
@classmethod
|
||||
async def create(cls, guildid, targetid, moderatorid, content, **kwargs):
|
||||
"""
|
||||
Create a new Warning for the target.
|
||||
|
||||
`kwargs` are passed transparently to the table insert method.
|
||||
"""
|
||||
ticket_row = tickets.insert(
|
||||
guildid=guildid,
|
||||
targetid=targetid,
|
||||
ticket_type=cls._ticket_type,
|
||||
ticket_state=TicketState.OPEN,
|
||||
moderator_id=moderatorid,
|
||||
content=content,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Create the note ticket
|
||||
ticket = cls(ticket_row['ticketid'])
|
||||
|
||||
# Post the ticket and return
|
||||
await ticket.post()
|
||||
return ticket
|
||||
|
||||
async def _revert(*args, **kwargs):
|
||||
# Warnings don't have a revert process
|
||||
pass
|
||||
|
||||
|
||||
@module.cmd(
|
||||
"warn",
|
||||
group="Moderation",
|
||||
desc="Officially warn a user for a misbehaviour."
|
||||
)
|
||||
@guild_moderator()
|
||||
async def cmd_warn(ctx):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}warn @target
|
||||
{prefix}warn @target <reason>
|
||||
Description:
|
||||
|
||||
The `target` must be specificed by mention or user id.
|
||||
If the `reason` is not given, it will be prompted for.
|
||||
Example:
|
||||
{prefix}warn {ctx.author.mention} Don't actually read the documentation!
|
||||
"""
|
||||
if not ctx.args:
|
||||
return await ctx.error_reply(
|
||||
"**Usage:** `{}warn @target <reason>`.".format(ctx.best_prefix)
|
||||
)
|
||||
|
||||
# Extract the target. We do require them to be in the server
|
||||
splits = ctx.args.split(maxsplit=1)
|
||||
target_str = splits[0].strip('<@!&> ')
|
||||
if not target_str.isdigit():
|
||||
return await ctx.error_reply(
|
||||
"**Usage:** `{}warn @target <reason>`.\n"
|
||||
"`target` must be provided by mention or userid.".format(ctx.best_prefix)
|
||||
)
|
||||
targetid = int(target_str)
|
||||
target = ctx.guild.get_member(targetid)
|
||||
if not target:
|
||||
return await ctx.error_reply("Cannot warn a user who is not in the server!")
|
||||
|
||||
# Extract or prompt for the content
|
||||
if len(splits) != 2:
|
||||
try:
|
||||
content = await ctx.input("Please give a reason for this warning!", timeout=300)
|
||||
except ResponseTimedOut:
|
||||
raise ResponseTimedOut("Prompt timed out, the member was not warned.")
|
||||
else:
|
||||
content = splits[1].strip()
|
||||
|
||||
# Create the warn ticket
|
||||
ticket = await WarnTicket.create(
|
||||
ctx.guild.id,
|
||||
targetid,
|
||||
ctx.author.id,
|
||||
content
|
||||
)
|
||||
|
||||
# Attempt to message the member
|
||||
embed = discord.Embed(
|
||||
title="You have received a warning!",
|
||||
description=(
|
||||
content
|
||||
),
|
||||
colour=discord.Colour.red(),
|
||||
timestamp=datetime.datetime.utcnow()
|
||||
)
|
||||
embed.add_field(
|
||||
name="Info",
|
||||
value=(
|
||||
"*Warnings appear in your moderation history. "
|
||||
"Failure to comply, or repeated warnings, "
|
||||
"may result in muting, studybanning, or server banning.*"
|
||||
)
|
||||
)
|
||||
embed.set_footer(
|
||||
icon_url=ctx.guild.icon_url,
|
||||
text=ctx.guild.name
|
||||
)
|
||||
dm_msg = None
|
||||
try:
|
||||
dm_msg = await target.send(embed=embed)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
# Get previous warnings
|
||||
count = tickets.select_one_where(
|
||||
guildid=ctx.guild.id,
|
||||
targetid=targetid,
|
||||
ticket_type=TicketType.WARNING,
|
||||
ticket_state=[TicketState.OPEN, TicketState.EXPIRING],
|
||||
select_columns=('COUNT(*)',)
|
||||
)[0]
|
||||
if count == 1:
|
||||
prev_str = "This is their first warning."
|
||||
else:
|
||||
prev_str = "They now have `{}` warnings.".format(count)
|
||||
|
||||
await ctx.embed_reply(
|
||||
"[Ticket #{}]({}): {} has been warned. {}\n{}".format(
|
||||
ticket.data.guild_ticketid,
|
||||
ticket.link,
|
||||
target.mention,
|
||||
prev_str,
|
||||
"*Could not DM the user their warning!*" if not dm_msg else ''
|
||||
)
|
||||
)
|
||||
4
src/modules/pending-rewrite/moderation/video/__init__.py
Normal file
4
src/modules/pending-rewrite/moderation/video/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from . import data
|
||||
from . import admin
|
||||
|
||||
from . import watchdog
|
||||
128
src/modules/pending-rewrite/moderation/video/admin.py
Normal file
128
src/modules/pending-rewrite/moderation/video/admin.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from settings import GuildSettings, GuildSetting
|
||||
from wards import guild_admin
|
||||
|
||||
import settings
|
||||
|
||||
from .data import video_channels
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class video_channels(settings.ChannelList, settings.ListData, settings.Setting):
|
||||
category = "Video Channels"
|
||||
|
||||
attr_name = 'video_channels'
|
||||
|
||||
_table_interface = video_channels
|
||||
_id_column = 'guildid'
|
||||
_data_column = 'channelid'
|
||||
_setting = settings.VoiceChannel
|
||||
|
||||
write_ward = guild_admin
|
||||
display_name = "video_channels"
|
||||
desc = "Channels where members are required to enable their video."
|
||||
|
||||
_force_unique = True
|
||||
|
||||
long_desc = (
|
||||
"Members must keep their video enabled in these channels.\n"
|
||||
"If they do not keep their video enabled, they will be asked to enable it in their DMS after `15` seconds, "
|
||||
"and then kicked from the channel with another warning after the `video_grace_period` duration has passed.\n"
|
||||
"After the first offence, if the `video_studyban` is enabled and the `studyban_role` is set, "
|
||||
"they will also be automatically studybanned."
|
||||
)
|
||||
|
||||
# Flat cache, no need to expire objects
|
||||
_cache = {}
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "Members must enable their video in the following channels:\n{}".format(self.formatted)
|
||||
else:
|
||||
return "There are no video-required channels set up."
|
||||
|
||||
@classmethod
|
||||
async def launch_task(cls, client):
|
||||
"""
|
||||
Launch initialisation step for the `video_channels` setting.
|
||||
|
||||
Pre-fill cache for the guilds with currently active voice channels.
|
||||
"""
|
||||
active_guildids = [
|
||||
guild.id
|
||||
for guild in client.guilds
|
||||
if any(channel.members for channel in guild.voice_channels)
|
||||
]
|
||||
if active_guildids:
|
||||
cache = {guildid: [] for guildid in active_guildids}
|
||||
rows = cls._table_interface.select_where(
|
||||
guildid=active_guildids
|
||||
)
|
||||
for row in rows:
|
||||
cache[row['guildid']].append(row['channelid'])
|
||||
cls._cache.update(cache)
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class video_studyban(settings.Boolean, GuildSetting):
|
||||
category = "Video Channels"
|
||||
|
||||
attr_name = 'video_studyban'
|
||||
_data_column = 'video_studyban'
|
||||
|
||||
display_name = "video_studyban"
|
||||
desc = "Whether to studyban members if they don't enable their video."
|
||||
|
||||
long_desc = (
|
||||
"If enabled, members who do not enable their video in the configured `video_channels` will be "
|
||||
"study-banned after a single warning.\n"
|
||||
"When disabled, members will only be warned and removed from the channel."
|
||||
)
|
||||
|
||||
_default = True
|
||||
_outputs = {True: "Enabled", False: "Disabled"}
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "Members will now be study-banned if they don't enable their video in the configured video channels."
|
||||
else:
|
||||
return "Members will not be study-banned if they don't enable their video in video channels."
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class video_grace_period(settings.Duration, GuildSetting):
|
||||
category = "Video Channels"
|
||||
|
||||
attr_name = 'video_grace_period'
|
||||
_data_column = 'video_grace_period'
|
||||
|
||||
display_name = "video_grace_period"
|
||||
desc = "How long to wait before kicking/studybanning members who don't enable their video."
|
||||
|
||||
long_desc = (
|
||||
"The period after a member has been asked to enable their video in a video-only channel "
|
||||
"before they will be kicked from the channel, and warned or studybanned (if enabled)."
|
||||
)
|
||||
|
||||
_default = 90
|
||||
_default_multiplier = 1
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, id: int, data, **kwargs):
|
||||
"""
|
||||
Return the string version of the data.
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
else:
|
||||
return "`{} seconds`".format(data)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
return (
|
||||
"Members who do not enable their video will "
|
||||
"be disconnected after {}.".format(self.formatted)
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user