sharding (reminders): Adapt for sharding.

Restrict reminder execution to shard `0`.
Add a poll on shard `0` to pick up new reminders.
Check whether the reminder still exists on execution.
This commit is contained in:
2021-12-22 20:24:24 +02:00
parent 0dd5213f13
commit e979e5cf45
2 changed files with 67 additions and 20 deletions

View File

@@ -3,6 +3,7 @@ import asyncio
import datetime import datetime
import discord import discord
from meta import sharding
from utils.lib import parse_dur, parse_ranges, multiselect_regex from utils.lib import parse_dur, parse_ranges, multiselect_regex
from .module import module from .module import module
@@ -55,7 +56,7 @@ async def cmd_remindme(ctx, flags):
if not rows: if not rows:
return await ctx.reply("You have no reminders to remove!") return await ctx.reply("You have no reminders to remove!")
live = Reminder.fetch(*(row.reminderid for row in rows)) live = [Reminder(row.reminderid) for row in rows]
if not ctx.args: if not ctx.args:
lines = [] lines = []
@@ -209,6 +210,7 @@ async def cmd_remindme(ctx, flags):
) )
# Schedule reminder # Schedule reminder
if sharding.shard_number == 0:
reminder.schedule() reminder.schedule()
# Ack # Ack
@@ -231,7 +233,7 @@ async def cmd_remindme(ctx, flags):
if not rows: if not rows:
return await ctx.reply("You have no reminders!") return await ctx.reply("You have no reminders!")
live = Reminder.fetch(*(row.reminderid for row in rows)) live = [Reminder(row.reminderid) for row in rows]
lines = [] lines = []
num_field = len(str(len(live) - 1)) num_field = len(str(len(live) - 1))

View File

@@ -1,8 +1,9 @@
import asyncio import asyncio
import datetime import datetime
import logging
import discord import discord
from meta import client from meta import client, sharding
from utils.lib import strfdur from utils.lib import strfdur
from .data import reminders from .data import reminders
@@ -46,7 +47,10 @@ class Reminder:
cls._live_reminders[reminderid].cancel() cls._live_reminders[reminderid].cancel()
# Remove from data # Remove from data
reminders.delete_where(reminderid=reminderids) if reminderids:
return reminders.delete_where(reminderid=reminderids)
else:
return []
@property @property
def data(self): def data(self):
@@ -134,10 +138,16 @@ class Reminder:
""" """
Execute the reminder. Execute the reminder.
""" """
if not self.data:
# Reminder deleted elsewhere
return
if self.data.userid in client.user_blacklist(): if self.data.userid in client.user_blacklist():
self.delete(self.reminderid) self.delete(self.reminderid)
return return
userid = self.data.userid
# Build the message embed # Build the message embed
embed = discord.Embed( embed = discord.Embed(
title="You asked me to remind you!", title="You asked me to remind you!",
@@ -155,8 +165,26 @@ class Reminder:
) )
) )
# Update the reminder data, and reschedule if required
if self.data.interval:
next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval)
rows = reminders.update_where(
{'remind_at': next_time},
reminderid=self.reminderid
)
self.schedule()
else:
rows = self.delete(self.reminderid)
if not rows:
# Reminder deleted elsewhere
return
# Send the message, if possible # Send the message, if possible
user = self.user if not (user := client.get_user(userid)):
try:
user = await client.fetch_user(userid)
except discord.HTTPException:
pass
if user: if user:
try: try:
await user.send(embed=embed) await user.send(embed=embed)
@@ -164,17 +192,32 @@ class Reminder:
# Nothing we can really do here. Maybe tell the user about their reminder next time? # Nothing we can really do here. Maybe tell the user about their reminder next time?
pass pass
# Update the reminder data, and reschedule if required
if self.data.interval: async def reminder_poll(client):
next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval) """
reminders.update_where({'remind_at': next_time}, reminderid=self.reminderid) One client/shard must continually poll for new or deleted reminders.
self.schedule() """
else: # TODO: Clean this up with database signals or IPC
self.delete(self.reminderid) while True:
await asyncio.sleep(60)
client.log(
"Running new reminder poll.",
context="REMINDERS",
level=logging.DEBUG
)
rids = {row.reminderid for row in reminders.fetch_rows_where()}
to_delete = (rid for rid in Reminder._live_reminders if rid not in rids)
Reminder.delete(*to_delete)
[Reminder(rid).schedule() for rid in rids if rid not in Reminder._live_reminders]
@module.launch_task @module.launch_task
async def schedule_reminders(client): async def schedule_reminders(client):
if sharding.shard_number == 0:
rows = reminders.fetch_rows_where() rows = reminders.fetch_rows_where()
for row in rows: for row in rows:
Reminder(row.reminderid).schedule() Reminder(row.reminderid).schedule()
@@ -182,3 +225,5 @@ async def schedule_reminders(client):
"Scheduled {} reminders.".format(len(rows)), "Scheduled {} reminders.".format(len(rows)),
context="LAUNCH_REMINDERS" context="LAUNCH_REMINDERS"
) )
if sharding.sharded:
asyncio.create_task(reminder_poll(client))