rewrite: Batch flush for WebhookLogger.

2022-11-04 09:37:21 +02:00
parent fd04b825f2
commit 322f519640

@@ -4,9 +4,10 @@ import asyncio
 from logging.handlers import QueueListener, QueueHandler
 from queue import SimpleQueue
 from contextlib import contextmanager
+from io import StringIO
 from contextvars import ContextVar
-from discord import AllowedMentions, Webhook
+from discord import AllowedMentions, Webhook, File
 import aiohttp
 from .config import conf
@@ -111,7 +112,7 @@ logger.addHandler(logging_handler_err)

 class LocalQueueHandler(QueueHandler):
-    def emit(self, record: logging.LogRecord) -> None:
+    def _emit(self, record: logging.LogRecord) -> None:
         # Removed the call to self.prepare(), handle task cancellation
         try:
             self.enqueue(record)
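
Context for the handler above: the stock QueueHandler.emit calls self.prepare(), which formats the record and strips args/exc_info so it can cross process boundaries; when the queue stays in-process, that step only loses detail. A minimal sketch of a handler that skips it (the class name here is illustrative, not this module's):

import asyncio
import logging
from logging.handlers import QueueHandler

class SameProcessQueueHandler(QueueHandler):
    """Skip prepare(): the queue never leaves this process, so records
    can keep their args and exc_info instead of being flattened to strings."""

    def emit(self, record: logging.LogRecord) -> None:
        try:
            self.enqueue(record)  # enqueue the record untouched
        except asyncio.CancelledError:
            raise                 # let task cancellation propagate
        except Exception:
            self.handleError(record)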
@@ -122,67 +123,91 @@ class LocalQueueHandler(QueueHandler):

 class WebHookHandler(logging.StreamHandler):
-    def __init__(self, webhook_url, batch=False):
+    def __init__(self, webhook_url, batch=False, loop=None):
         super().__init__(self)
         self.webhook_url = webhook_url
         self.batched = ""
         self.batch = batch
-        self.loop = None
+        self.loop = loop
+        self.batch_delay = 10
+        self.batch_task = None
+        self.last_batched = None
+        self.waiting = []

     def get_loop(self):
         if self.loop is None:
             self.loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(self.loop)
         return self.loop

     def emit(self, record):
-        self.get_loop().run_until_complete(self.post(record))
+        self.get_loop().call_soon_threadsafe(self._post, record)
+
+    def _post(self, record):
+        asyncio.create_task(self.post(record))

     async def post(self, record):
         try:
             timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
-            header = f"[{record.levelname}][{record.app}][{record.context}][{record.action}][{timestamp}]"
-            message = record.msg
-
-            # TODO: Maybe send file instead of splitting?
-            # TODO: Reformat header a little
+            header = f"[{timestamp}][{record.levelname}][{record.app}][{record.action}][{record.context}]\n"
+            message = header+record.msg
+
             if len(message) > 1900:
-                blocks = split_text(message, blocksize=1900, code=False)
+                as_file = True
             else:
-                blocks = [message]
-
-            if len(blocks) > 1:
-                blocks = [
-                    "```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
-                ]
-            else:
-                blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
+                as_file = False
+                message = "```md\n{}\n```".format(message)

             # Post the log message(s)
             if self.batch:
-                if len(message) > 500:
-                    await self._send_batched()
-                    await self._send(*blocks)
-                elif len(self.batched) + len(blocks[0]) > 500:
-                    self.batched += blocks[0]
-                    await self._send_batched()
+                if len(message) > 1000:
+                    await self._send_batched_now()
+                    await self._send(message, as_file=as_file)
                 else:
-                    self.batched += blocks[0]
+                    self.batched += message
+                    if len(self.batched) + len(message) > 1000:
+                        await self._send_batched_now()
+                    else:
+                        asyncio.create_task(self._schedule_batched())
             else:
-                await self._send(*blocks)
+                await self._send(message, as_file=as_file)
         except Exception as ex:
             print(ex)

+    async def _schedule_batched(self):
+        if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
+            # noop, don't reschedule if it is already scheduled
+            return
+        try:
+            self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
+            await self.batch_task
+            await self._send_batched()
+        except asyncio.CancelledError:
+            return
+        except Exception as ex:
+            print(ex)
+
+    async def _send_batched_now(self):
+        if self.batch_task is not None and not self.batch_task.done():
+            self.batch_task.cancel()
+        self.last_batched = None
+        await self._send_batched()
+
     async def _send_batched(self):
         if self.batched:
             batched = self.batched
             self.batched = ""
             await self._send(batched)

-    async def _send(self, *blocks):
+    async def _send(self, message, as_file=False):
         async with aiohttp.ClientSession() as session:
             webhook = Webhook.from_url(self.webhook_url, session=session)
-            for block in blocks:
-                await webhook.send(block)
+            if as_file or len(message) > 2000:
+                with StringIO(message) as fp:
+                    fp.seek(0)
+                    await webhook.send(file=File(fp, filename="logs.md"))
+            else:
+                await webhook.send(message)


 handlers = []
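
The batched path above is a debounce: each message is appended to a buffer, a single delayed task flushes it after batch_delay seconds, and an oversized buffer cancels that task and flushes immediately. A self-contained sketch of the same pattern, with hypothetical names (BatchBuffer, flush_cb) that are not part of this module:

import asyncio

class BatchBuffer:
    """Debounced buffer: flush after a delay, or immediately when full."""

    def __init__(self, flush_cb, delay=10, max_size=1000):
        self.flush_cb = flush_cb  # coroutine receiving the joined batch
        self.delay = delay
        self.max_size = max_size
        self.buffer = ""
        self.task = None          # the single pending delayed-flush task

    async def add(self, text):
        self.buffer += text
        if len(self.buffer) > self.max_size:
            await self.flush_now()
        elif self.task is None or self.task.done():
            # Schedule one delayed flush; later adds ride along with it
            self.task = asyncio.create_task(self._delayed_flush())

    async def _delayed_flush(self):
        try:
            await asyncio.sleep(self.delay)
            await self._flush()
        except asyncio.CancelledError:
            pass                  # flush_now() took over

    async def flush_now(self):
        if self.task is not None and not self.task.done():
            self.task.cancel()
        await self._flush()

    async def _flush(self):
        if self.buffer:
            batch, self.buffer = self.buffer, ""
            await self.flush_cb(batch)

Cancelling the sleeping task inside _delayed_flush is what lets flush_now win the race without a double send, mirroring how _send_batched_now cancels batch_task above.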
@@ -201,77 +226,33 @@ if webhook := conf.logging['critical_log']:
     handlers.append(handler)

 if handlers:
+    # First create a separate loop to run the handlers on
+    import threading
+
+    def run_loop(loop):
+        asyncio.set_event_loop(loop)
+        try:
+            loop.run_forever()
+        finally:
+            loop.run_until_complete(loop.shutdown_asyncgens())
+            loop.close()
+
+    loop = asyncio.new_event_loop()
+    loop_thread = threading.Thread(target=lambda: run_loop(loop))
+    loop_thread.daemon = True
+    loop_thread.start()
+
+    for handler in handlers:
+        handler.loop = loop
+
     queue: SimpleQueue[logging.LogRecord] = SimpleQueue()
-    handler = QueueHandler(queue)
-    handler.setLevel(logging.INFO)
-    handler.addFilter(ContextInjection())
-    logger.addHandler(handler)
+    qhandler = QueueHandler(queue)
+    qhandler.setLevel(logging.INFO)
+    qhandler.addFilter(ContextInjection())
+    logger.addHandler(qhandler)

     listener = QueueListener(
         queue, *handlers, respect_handler_level=True
     )
     listener.start()

-# QueueHandler to feed entries to a Queue
-# On the other end of the Queue, feed to the webhook
-# TODO: Add an async handler for posting
-# Subclass this, create a DiscordChannelHandler, taking a Client and a channel as an argument
-# Then we can handle error channels etc differently
-# The formatting can be handled with a custom handler as well
-# Define the context log format and attach it to the command logger as well
-
-def log(message, context="GLOBAL", level=logging.INFO, post=True):
-    # Add prefixes to lines for better parsing capability
-    lines = message.splitlines()
-    if len(lines) > 1:
-        lines = [
-            '┌ ' * (i == 0) + '│ ' * (0 < i < len(lines) - 1) + '└ ' * (i == len(lines) - 1) + line
-            for i, line in enumerate(lines)
-        ]
-    else:
-        lines = ['─ ' + message]
-
-    for line in lines:
-        logger.log(level, '\b[{}] {}'.format(
-            str(context).center(22, '='),
-            line
-        ))
-
-    # Fire and forget to the channel logger, if it is set up
-    if post and client.is_ready():
-        asyncio.ensure_future(live_log(message, context, level))
-
-
-# Live logger that posts to the logging channels
-async def live_log(message, context, level):
-    if level >= logging.INFO:
-        if level >= logging.WARNING:
-            log_chid = conf.bot.getint('error_channel') or conf.bot.getint('log_channel')
-        else:
-            log_chid = conf.bot.getint('log_channel')
-
-        # Generate the log messages
-        if sharding.sharded:
-            header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
-        else:
-            header = f"[{logging.getLevelName(level)}][{context}]"
-
-        if len(message) > 1900:
-            blocks = split_text(message, blocksize=1900, code=False)
-        else:
-            blocks = [message]
-
-        if len(blocks) > 1:
-            blocks = [
-                "```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
-            ]
-        else:
-            blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
-
-        # Post the log messages
-        if log_chid:
-            [await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]
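
The new setup runs every webhook handler on one event loop in a daemon thread, so the synchronous emit() called by the logging machinery can hand records to async code without blocking. A rough, self-contained sketch of that wiring, using illustrative names (start_background_loop, BackgroundLoopHandler) rather than this module's own:

import asyncio
import logging
import threading
import time

def start_background_loop() -> asyncio.AbstractEventLoop:
    """Run a fresh event loop forever on a daemon thread and return it."""
    loop = asyncio.new_event_loop()

    def runner():
        asyncio.set_event_loop(loop)
        try:
            loop.run_forever()
        finally:
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()

    threading.Thread(target=runner, daemon=True).start()
    return loop

class BackgroundLoopHandler(logging.Handler):
    """Forward records from synchronous logging calls onto an asyncio loop."""

    def __init__(self, loop: asyncio.AbstractEventLoop):
        super().__init__()
        self.loop = loop

    def emit(self, record: logging.LogRecord) -> None:
        # call_soon_threadsafe is the safe handoff from another thread
        self.loop.call_soon_threadsafe(self._post, record)

    def _post(self, record: logging.LogRecord) -> None:
        # Now running on the loop's thread: safe to spawn the async sender
        self.loop.create_task(self.post(record))

    async def post(self, record: logging.LogRecord) -> None:
        print("would post:", self.format(record))  # webhook send goes here

loop = start_background_loop()
log = logging.getLogger("demo")
log.addHandler(BackgroundLoopHandler(loop))
log.warning("hello from the main thread")
time.sleep(1)  # give the daemon loop a moment to process before exit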