rewrite: Batch flush for WebhookLogger.

2022-11-04 09:37:21 +02:00
parent fd04b825f2
commit 322f519640
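
In outline, the diff below makes the webhook handler buffer short log messages, flush the buffer either after `batch_delay` seconds or immediately once it grows past a size threshold, and post oversized messages as a file attachment instead of splitting them into code blocks. As a simplified, standalone sketch of that debounce-style flush (the `BatchBuffer` name and its parameters are illustrative, not part of the commit):

import asyncio

class BatchBuffer:
    # Collects short messages and flushes them after a delay,
    # or immediately once the buffer grows too large.
    def __init__(self, send, flush_after=10, max_size=1000):
        self.send = send              # async callable taking the joined text
        self.flush_after = flush_after
        self.max_size = max_size
        self.buffer = ""
        self._flush_task = None

    async def add(self, text):
        self.buffer += text
        if len(self.buffer) > self.max_size:
            # Too much pending text: cancel any scheduled flush and send now
            await self.flush_now()
        elif self._flush_task is None or self._flush_task.done():
            # Nothing scheduled yet: flush after the delay
            self._flush_task = asyncio.create_task(self._delayed_flush())

    async def _delayed_flush(self):
        try:
            await asyncio.sleep(self.flush_after)
        except asyncio.CancelledError:
            return
        await self._flush()

    async def flush_now(self):
        if self._flush_task is not None and not self._flush_task.done():
            self._flush_task.cancel()
        await self._flush()

    async def _flush(self):
        if self.buffer:
            pending, self.buffer = self.buffer, ""
            await self.send(pending)

The handler in the diff follows the same shape, with _schedule_batched playing the role of the delayed flush and _send_batched_now the immediate one.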


@@ -4,9 +4,10 @@ import asyncio
from logging.handlers import QueueListener, QueueHandler
from queue import SimpleQueue
from contextlib import contextmanager
from io import StringIO
from contextvars import ContextVar
from discord import AllowedMentions, Webhook
from discord import AllowedMentions, Webhook, File
import aiohttp
from .config import conf
@@ -111,7 +112,7 @@ logger.addHandler(logging_handler_err)
class LocalQueueHandler(QueueHandler):
def emit(self, record: logging.LogRecord) -> None:
def _emit(self, record: logging.LogRecord) -> None:
# Removed the call to self.prepare(); handles task cancellation
try:
self.enqueue(record)
@@ -122,67 +123,91 @@ class LocalQueueHandler(QueueHandler):
class WebHookHandler(logging.StreamHandler):
def __init__(self, webhook_url, batch=False):
def __init__(self, webhook_url, batch=False, loop=None):
super().__init__(self)
self.webhook_url = webhook_url
self.batched = ""
self.batch = batch
self.loop = None
self.loop = loop
self.batch_delay = 10
self.batch_task = None
self.last_batched = None
self.waiting = []
def get_loop(self):
if self.loop is None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
return self.loop
def emit(self, record):
self.get_loop().run_until_complete(self.post(record))
self.get_loop().call_soon_threadsafe(self._post, record)
def _post(self, record):
asyncio.create_task(self.post(record))
async def post(self, record):
try:
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
header = f"[{record.levelname}][{record.app}][{record.context}][{record.action}][{timestamp}]"
message = record.msg
header = f"[{timestamp}][{record.levelname}][{record.app}][{record.action}][{record.context}]\n"
message = header+record.msg
# TODO: Maybe send file instead of splitting?
# TODO: Reformat header a little
if len(message) > 1900:
blocks = split_text(message, blocksize=1900, code=False)
as_file = True
else:
blocks = [message]
if len(blocks) > 1:
blocks = [
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
]
else:
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
as_file = False
message = "```md\n{}\n```".format(message)
# Post the log message(s)
if self.batch:
if len(message) > 500:
await self._send_batched()
await self._send(*blocks)
elif len(self.batched) + len(blocks[0]) > 500:
self.batched += blocks[0]
await self._send_batched()
if len(message) > 1000:
await self._send_batched_now()
await self._send(message, as_file=as_file)
else:
self.batched += blocks[0]
self.batched += message
if len(self.batched) + len(message) > 1000:
await self._send_batched_now()
else:
await self._send(*blocks)
asyncio.create_task(self._schedule_batched())
else:
await self._send(message, as_file=as_file)
except Exception as ex:
print(ex)
async def _schedule_batched(self):
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
# noop, don't reschedule if it is already scheduled
return
try:
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
await self.batch_task
await self._send_batched()
except asyncio.CancelledError:
return
except Exception as ex:
print(ex)
async def _send_batched_now(self):
if self.batch_task is not None and not self.batch_task.done():
self.batch_task.cancel()
self.last_batched = None
await self._send_batched()
async def _send_batched(self):
if self.batched:
batched = self.batched
self.batched = ""
await self._send(batched)
async def _send(self, *blocks):
async def _send(self, message, as_file=False):
async with aiohttp.ClientSession() as session:
webhook = Webhook.from_url(self.webhook_url, session=session)
for block in blocks:
await webhook.send(block)
if as_file or len(message) > 2000:
with StringIO(message) as fp:
fp.seek(0)
await webhook.send(file=File(fp, filename="logs.md"))
else:
await webhook.send(message)
handlers = []
@@ -201,77 +226,33 @@ if webhook := conf.logging['critical_log']:
handlers.append(handler)
if handlers:
# First create a separate loop to run the handlers on
import threading
def run_loop(loop):
asyncio.set_event_loop(loop)
try:
loop.run_forever()
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
loop = asyncio.new_event_loop()
loop_thread = threading.Thread(target=lambda: run_loop(loop))
loop_thread.daemon = True
loop_thread.start()
for handler in handlers:
handler.loop = loop
queue: SimpleQueue[logging.LogRecord] = SimpleQueue()
handler = QueueHandler(queue)
handler.setLevel(logging.INFO)
handler.addFilter(ContextInjection())
logger.addHandler(handler)
qhandler = QueueHandler(queue)
qhandler.setLevel(logging.INFO)
qhandler.addFilter(ContextInjection())
logger.addHandler(qhandler)
listener = QueueListener(
queue, *handlers, respect_handler_level=True
)
listener.start()
# QueueHandler to feed entries to a Queue
# On the other end of the Queue, feed to the webhook
# TODO: Add an async handler for posting
# Subclass this, create a DiscordChannelHandler, taking a Client and a channel as an argument
# Then we can handle error channels etc differently
# The formatting can be handled with a custom handler as well
# Define the context log format and attach it to the command logger as well
def log(message, context="GLOBAL", level=logging.INFO, post=True):
# Add prefixes to lines for better parsing capability
lines = message.splitlines()
if len(lines) > 1:
lines = [
'' * (i == 0) + '' * (0 < i < len(lines) - 1) + '' * (i == len(lines) - 1) + line
for i, line in enumerate(lines)
]
else:
lines = ['' + message]
for line in lines:
logger.log(level, '\b[{}] {}'.format(
str(context).center(22, '='),
line
))
# Fire and forget to the channel logger, if it is set up
if post and client.is_ready():
asyncio.ensure_future(live_log(message, context, level))
# Live logger that posts to the logging channels
async def live_log(message, context, level):
if level >= logging.INFO:
if level >= logging.WARNING:
log_chid = conf.bot.getint('error_channel') or conf.bot.getint('log_channel')
else:
log_chid = conf.bot.getint('log_channel')
# Generate the log messages
if sharding.sharded:
header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
else:
header = f"[{logging.getLevelName(level)}][{context}]"
if len(message) > 1900:
blocks = split_text(message, blocksize=1900, code=False)
else:
blocks = [message]
if len(blocks) > 1:
blocks = [
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
]
else:
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
# Post the log messages
if log_chid:
[await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]
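
After this change the webhook handlers are driven on a dedicated event loop running in a daemon thread, and emit hands each record across with call_soon_threadsafe. A minimal standalone sketch of that pattern (independent of the commit; post here is only a stand-in for the webhook send):

import asyncio
import threading
import time

def start_background_loop():
    # Run an event loop forever in a daemon thread and return it
    loop = asyncio.new_event_loop()

    def run():
        asyncio.set_event_loop(loop)
        try:
            loop.run_forever()
        finally:
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()

    threading.Thread(target=run, daemon=True).start()
    return loop

async def post(message):
    # Stand-in for the async webhook send
    print(message)

loop = start_background_loop()
# Synchronous code (e.g. a logging handler's emit) schedules work on the loop's thread:
loop.call_soon_threadsafe(lambda: loop.create_task(post("hello from another thread")))
time.sleep(0.5)  # give the daemon loop a moment before the program exits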