rewrite: New bot framework.
This commit is contained in:
40
bot/meta/LionBot.py
Normal file
40
bot/meta/LionBot.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from data import Database
|
||||
|
||||
from .config import Conf
|
||||
|
||||
|
||||
class LionBot(commands.Bot):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
appname: str,
|
||||
db: Database,
|
||||
config: Conf,
|
||||
initial_extensions: List[str],
|
||||
web_client: ClientSession,
|
||||
testing_guilds: List[int] = [],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.web_client = web_client
|
||||
self.testing_guilds = testing_guilds
|
||||
self.initial_extensions = initial_extensions
|
||||
self.db = db
|
||||
self.appname = appname
|
||||
# self.appdata = appdata
|
||||
self.config = config
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
for extension in self.initial_extensions:
|
||||
await self.load_extension(extension)
|
||||
|
||||
for guildid in self.testing_guilds:
|
||||
guild = discord.Object(guildid)
|
||||
self.tree.copy_global_to(guild=guild)
|
||||
await self.tree.sync(guild=guild)
|
||||
@@ -1,9 +1,5 @@
|
||||
from .logger import log, logger
|
||||
|
||||
from . import interactions
|
||||
from . import patches
|
||||
|
||||
from .client import client
|
||||
from .LionBot import LionBot
|
||||
from .config import conf
|
||||
from .args import args
|
||||
from . import sharding
|
||||
from . import logger
|
||||
|
||||
@@ -27,7 +27,7 @@ class configEmoji(PartialEmoji):
|
||||
animated, name, id = emojistr.split(':')
|
||||
return cls(
|
||||
name=name,
|
||||
fallback=PartialEmoji(name=fallback),
|
||||
fallback=PartialEmoji(name=fallback) if fallback is not None else None,
|
||||
animated=bool(animated),
|
||||
id=int(id)
|
||||
)
|
||||
@@ -60,11 +60,26 @@ class MapDotProxy:
|
||||
return self._map.__getitem__(key)
|
||||
|
||||
|
||||
class ConfigParser(cfgp.ConfigParser):
|
||||
"""
|
||||
Extension of base ConfigParser allowing optional
|
||||
section option retrieval without defaults.
|
||||
"""
|
||||
def options(self, section, no_defaults=False, **kwargs):
|
||||
if no_defaults:
|
||||
try:
|
||||
return list(self._sections[section].keys())
|
||||
except KeyError:
|
||||
raise cfgp.NoSectionError(section)
|
||||
else:
|
||||
return super().options(section, **kwargs)
|
||||
|
||||
|
||||
class Conf:
|
||||
def __init__(self, configfile, section_name="DEFAULT"):
|
||||
self.configfile = configfile
|
||||
|
||||
self.config = cfgp.ConfigParser(
|
||||
self.config = ConfigParser(
|
||||
converters={
|
||||
"intlist": self._getintlist,
|
||||
"list": self._getlist,
|
||||
@@ -102,7 +117,7 @@ class Conf:
|
||||
return self.section[key].strip()
|
||||
|
||||
def __getattr__(self, section):
|
||||
return self.config[section]
|
||||
return self.config[section.upper()]
|
||||
|
||||
def get(self, name, fallback=None):
|
||||
result = self.section.get(name, fallback)
|
||||
@@ -119,4 +134,4 @@ class Conf:
|
||||
self.config.write(conffile)
|
||||
|
||||
|
||||
conf = Conf(args.config)
|
||||
conf = Conf(args.config, 'STUDYLION')
|
||||
146
bot/meta/logger.py
Normal file
146
bot/meta/logger.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
from contextvars import ContextVar
|
||||
from discord import AllowedMentions
|
||||
|
||||
from .config import conf
|
||||
from . import sharding
|
||||
|
||||
|
||||
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
|
||||
log_action: ContextVar[str] = ContextVar('logging_action', default='UNKNOWN ACTION')
|
||||
|
||||
|
||||
RESET_SEQ = "\033[0m"
|
||||
COLOR_SEQ = "\033[3%dm"
|
||||
BOLD_SEQ = "\033[1m"
|
||||
"]]]"
|
||||
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
|
||||
|
||||
def colour_escape(fmt: str) -> str:
|
||||
cmap = {
|
||||
'%(black)': COLOR_SEQ % BLACK,
|
||||
'%(red)': COLOR_SEQ % RED,
|
||||
'%(green)': COLOR_SEQ % GREEN,
|
||||
'%(yellow)': COLOR_SEQ % YELLOW,
|
||||
'%(blue)': COLOR_SEQ % BLUE,
|
||||
'%(magenta)': COLOR_SEQ % MAGENTA,
|
||||
'%(cyan)': COLOR_SEQ % CYAN,
|
||||
'%(white)': COLOR_SEQ % WHITE,
|
||||
'%(reset)': RESET_SEQ,
|
||||
'%(bold)': BOLD_SEQ,
|
||||
}
|
||||
for key, value in cmap.items():
|
||||
fmt = fmt.replace(key, value)
|
||||
return fmt
|
||||
|
||||
|
||||
log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' +
|
||||
'[%(cyan)SHARD {:02}%(reset)]'.format(sharding.shard_number) +
|
||||
'[%(cyan)%(context)-22s%(reset)]' +
|
||||
'[%(cyan)%(action)-22s%(reset)]' +
|
||||
' %(bold)%(cyan)%(name)s:%(reset)' +
|
||||
' %(white)%(message)s%(reset)')
|
||||
log_format = colour_escape(log_format)
|
||||
|
||||
|
||||
# Setup the logger
|
||||
logger = logging.getLogger()
|
||||
log_fmt = logging.Formatter(
|
||||
fmt=log_format,
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger.setLevel(logging.NOTSET)
|
||||
|
||||
|
||||
class LessThanFilter(logging.Filter):
|
||||
def __init__(self, exclusive_maximum, name=""):
|
||||
super(LessThanFilter, self).__init__(name)
|
||||
self.max_level = exclusive_maximum
|
||||
|
||||
def filter(self, record):
|
||||
# non-zero return means we log this message
|
||||
return 1 if record.levelno < self.max_level else 0
|
||||
|
||||
|
||||
class ContextInjection(logging.Filter):
|
||||
def filter(self, record):
|
||||
if not hasattr(record, 'context'):
|
||||
record.context = log_context.get()
|
||||
if not hasattr(record, 'action'):
|
||||
record.action = log_action.get()
|
||||
return True
|
||||
|
||||
|
||||
logging_handler_out = logging.StreamHandler(sys.stdout)
|
||||
logging_handler_out.setLevel(logging.DEBUG)
|
||||
logging_handler_out.setFormatter(log_fmt)
|
||||
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
|
||||
logging_handler_out.addFilter(ContextInjection())
|
||||
logger.addHandler(logging_handler_out)
|
||||
|
||||
logging_handler_err = logging.StreamHandler(sys.stderr)
|
||||
logging_handler_err.setLevel(logging.WARNING)
|
||||
logging_handler_err.setFormatter(log_fmt)
|
||||
logging_handler_err.addFilter(ContextInjection())
|
||||
logger.addHandler(logging_handler_err)
|
||||
|
||||
# TODO: Add an async handler for posting
|
||||
# Subclass this, create a DiscordChannelHandler, taking a Client and a channel as an argument
|
||||
# Then we can handle error channels etc differently
|
||||
# The formatting can be handled with a custom handler as well
|
||||
|
||||
|
||||
# Define the context log format and attach it to the command logger as well
|
||||
def log(message, context="GLOBAL", level=logging.INFO, post=True):
|
||||
# Add prefixes to lines for better parsing capability
|
||||
lines = message.splitlines()
|
||||
if len(lines) > 1:
|
||||
lines = [
|
||||
'┌ ' * (i == 0) + '│ ' * (0 < i < len(lines) - 1) + '└ ' * (i == len(lines) - 1) + line
|
||||
for i, line in enumerate(lines)
|
||||
]
|
||||
else:
|
||||
lines = ['─ ' + message]
|
||||
|
||||
for line in lines:
|
||||
logger.log(level, '\b[{}] {}'.format(
|
||||
str(context).center(22, '='),
|
||||
line
|
||||
))
|
||||
|
||||
# Fire and forget to the channel logger, if it is set up
|
||||
if post and client.is_ready():
|
||||
asyncio.ensure_future(live_log(message, context, level))
|
||||
|
||||
|
||||
# Live logger that posts to the logging channels
|
||||
async def live_log(message, context, level):
|
||||
if level >= logging.INFO:
|
||||
if level >= logging.WARNING:
|
||||
log_chid = conf.bot.getint('error_channel') or conf.bot.getint('log_channel')
|
||||
else:
|
||||
log_chid = conf.bot.getint('log_channel')
|
||||
|
||||
# Generate the log messages
|
||||
if sharding.sharded:
|
||||
header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
|
||||
else:
|
||||
header = f"[{logging.getLevelName(level)}][{context}]"
|
||||
|
||||
if len(message) > 1900:
|
||||
blocks = split_text(message, blocksize=1900, code=False)
|
||||
else:
|
||||
blocks = [message]
|
||||
|
||||
if len(blocks) > 1:
|
||||
blocks = [
|
||||
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
|
||||
]
|
||||
else:
|
||||
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
|
||||
|
||||
# Post the log messages
|
||||
if log_chid:
|
||||
[await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]
|
||||
@@ -1,9 +1,35 @@
|
||||
from .args import args
|
||||
from .config import conf
|
||||
|
||||
from psycopg import sql
|
||||
from data.conditions import Condition, Joiner
|
||||
|
||||
|
||||
shard_number = args.shard or 0
|
||||
|
||||
shard_count = conf.bot.getint('shard_count', 1)
|
||||
|
||||
sharded = (shard_count > 0)
|
||||
|
||||
|
||||
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
|
||||
"""
|
||||
Condition constructor for filtering by shard id.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
Query.where(_shard_condition('guildid', 10, 1))
|
||||
"""
|
||||
return Condition(
|
||||
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
|
||||
guildid=sql.Identifier(guild_column),
|
||||
shard_count=sql.Literal(shard_count)
|
||||
),
|
||||
Joiner.EQUALS,
|
||||
sql.Placeholder(),
|
||||
(shard_id,)
|
||||
)
|
||||
|
||||
|
||||
# Pre-built Condition for filtering by current shard.
|
||||
THIS_SHARD = SHARDID(shard_number)
|
||||
|
||||
Reference in New Issue
Block a user