rewrite: New bot framework.

This commit is contained in:
2022-11-02 07:24:57 +02:00
parent 069c032e02
commit b27ee447b3
8 changed files with 303 additions and 591 deletions

40
bot/meta/LionBot.py Normal file
View File

@@ -0,0 +1,40 @@
from typing import List, Optional, Dict
import discord
from discord.ext import commands
from aiohttp import ClientSession
from data import Database
from .config import Conf
class LionBot(commands.Bot):
def __init__(
self,
*args,
appname: str,
db: Database,
config: Conf,
initial_extensions: List[str],
web_client: ClientSession,
testing_guilds: List[int] = [],
**kwargs,
):
super().__init__(*args, **kwargs)
self.web_client = web_client
self.testing_guilds = testing_guilds
self.initial_extensions = initial_extensions
self.db = db
self.appname = appname
# self.appdata = appdata
self.config = config
async def setup_hook(self) -> None:
for extension in self.initial_extensions:
await self.load_extension(extension)
for guildid in self.testing_guilds:
guild = discord.Object(guildid)
self.tree.copy_global_to(guild=guild)
await self.tree.sync(guild=guild)

View File

@@ -1,9 +1,5 @@
from .logger import log, logger
from . import interactions
from . import patches
from .client import client
from .LionBot import LionBot
from .config import conf
from .args import args
from . import sharding
from . import logger

View File

@@ -27,7 +27,7 @@ class configEmoji(PartialEmoji):
animated, name, id = emojistr.split(':')
return cls(
name=name,
fallback=PartialEmoji(name=fallback),
fallback=PartialEmoji(name=fallback) if fallback is not None else None,
animated=bool(animated),
id=int(id)
)
@@ -60,11 +60,26 @@ class MapDotProxy:
return self._map.__getitem__(key)
class ConfigParser(cfgp.ConfigParser):
"""
Extension of base ConfigParser allowing optional
section option retrieval without defaults.
"""
def options(self, section, no_defaults=False, **kwargs):
if no_defaults:
try:
return list(self._sections[section].keys())
except KeyError:
raise cfgp.NoSectionError(section)
else:
return super().options(section, **kwargs)
class Conf:
def __init__(self, configfile, section_name="DEFAULT"):
self.configfile = configfile
self.config = cfgp.ConfigParser(
self.config = ConfigParser(
converters={
"intlist": self._getintlist,
"list": self._getlist,
@@ -102,7 +117,7 @@ class Conf:
return self.section[key].strip()
def __getattr__(self, section):
return self.config[section]
return self.config[section.upper()]
def get(self, name, fallback=None):
result = self.section.get(name, fallback)
@@ -119,4 +134,4 @@ class Conf:
self.config.write(conffile)
conf = Conf(args.config)
conf = Conf(args.config, 'STUDYLION')

146
bot/meta/logger.py Normal file
View File

@@ -0,0 +1,146 @@
import sys
import logging
import asyncio
from contextvars import ContextVar
from discord import AllowedMentions
from .config import conf
from . import sharding
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action: ContextVar[str] = ContextVar('logging_action', default='UNKNOWN ACTION')
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
"]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
def colour_escape(fmt: str) -> str:
cmap = {
'%(black)': COLOR_SEQ % BLACK,
'%(red)': COLOR_SEQ % RED,
'%(green)': COLOR_SEQ % GREEN,
'%(yellow)': COLOR_SEQ % YELLOW,
'%(blue)': COLOR_SEQ % BLUE,
'%(magenta)': COLOR_SEQ % MAGENTA,
'%(cyan)': COLOR_SEQ % CYAN,
'%(white)': COLOR_SEQ % WHITE,
'%(reset)': RESET_SEQ,
'%(bold)': BOLD_SEQ,
}
for key, value in cmap.items():
fmt = fmt.replace(key, value)
return fmt
log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' +
'[%(cyan)SHARD {:02}%(reset)]'.format(sharding.shard_number) +
'[%(cyan)%(context)-22s%(reset)]' +
'[%(cyan)%(action)-22s%(reset)]' +
' %(bold)%(cyan)%(name)s:%(reset)' +
' %(white)%(message)s%(reset)')
log_format = colour_escape(log_format)
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
fmt=log_format,
datefmt='%Y-%m-%d %H:%M:%S'
)
logger.setLevel(logging.NOTSET)
class LessThanFilter(logging.Filter):
def __init__(self, exclusive_maximum, name=""):
super(LessThanFilter, self).__init__(name)
self.max_level = exclusive_maximum
def filter(self, record):
# non-zero return means we log this message
return 1 if record.levelno < self.max_level else 0
class ContextInjection(logging.Filter):
def filter(self, record):
if not hasattr(record, 'context'):
record.context = log_context.get()
if not hasattr(record, 'action'):
record.action = log_action.get()
return True
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err)
# TODO: Add an async handler for posting
# Subclass this, create a DiscordChannelHandler, taking a Client and a channel as an argument
# Then we can handle error channels etc differently
# The formatting can be handled with a custom handler as well
# Define the context log format and attach it to the command logger as well
def log(message, context="GLOBAL", level=logging.INFO, post=True):
# Add prefixes to lines for better parsing capability
lines = message.splitlines()
if len(lines) > 1:
lines = [
'' * (i == 0) + '' * (0 < i < len(lines) - 1) + '' * (i == len(lines) - 1) + line
for i, line in enumerate(lines)
]
else:
lines = ['' + message]
for line in lines:
logger.log(level, '\b[{}] {}'.format(
str(context).center(22, '='),
line
))
# Fire and forget to the channel logger, if it is set up
if post and client.is_ready():
asyncio.ensure_future(live_log(message, context, level))
# Live logger that posts to the logging channels
async def live_log(message, context, level):
if level >= logging.INFO:
if level >= logging.WARNING:
log_chid = conf.bot.getint('error_channel') or conf.bot.getint('log_channel')
else:
log_chid = conf.bot.getint('log_channel')
# Generate the log messages
if sharding.sharded:
header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
else:
header = f"[{logging.getLevelName(level)}][{context}]"
if len(message) > 1900:
blocks = split_text(message, blocksize=1900, code=False)
else:
blocks = [message]
if len(blocks) > 1:
blocks = [
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
]
else:
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
# Post the log messages
if log_chid:
[await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]

View File

@@ -1,9 +1,35 @@
from .args import args
from .config import conf
from psycopg import sql
from data.conditions import Condition, Joiner
shard_number = args.shard or 0
shard_count = conf.bot.getint('shard_count', 1)
sharded = (shard_count > 0)
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
"""
Condition constructor for filtering by shard id.
Example Usage
-------------
Query.where(_shard_condition('guildid', 10, 1))
"""
return Condition(
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
guildid=sql.Identifier(guild_column),
shard_count=sql.Literal(shard_count)
),
Joiner.EQUALS,
sql.Placeholder(),
(shard_id,)
)
# Pre-built Condition for filtering by current shard.
THIS_SHARD = SHARDID(shard_number)