diff --git a/bot/modules/reminders/commands.py b/bot/modules/reminders/commands.py index 0bb98d60..c8637a04 100644 --- a/bot/modules/reminders/commands.py +++ b/bot/modules/reminders/commands.py @@ -3,6 +3,7 @@ import asyncio import datetime import discord +from meta import sharding from utils.lib import parse_dur, parse_ranges, multiselect_regex from .module import module @@ -55,7 +56,7 @@ async def cmd_remindme(ctx, flags): if not rows: return await ctx.reply("You have no reminders to remove!") - live = Reminder.fetch(*(row.reminderid for row in rows)) + live = [Reminder(row.reminderid) for row in rows] if not ctx.args: lines = [] @@ -209,7 +210,8 @@ async def cmd_remindme(ctx, flags): ) # Schedule reminder - reminder.schedule() + if sharding.shard_number == 0: + reminder.schedule() # Ack embed = discord.Embed( @@ -231,7 +233,7 @@ async def cmd_remindme(ctx, flags): if not rows: return await ctx.reply("You have no reminders!") - live = Reminder.fetch(*(row.reminderid for row in rows)) + live = [Reminder(row.reminderid) for row in rows] lines = [] num_field = len(str(len(live) - 1)) diff --git a/bot/modules/reminders/reminder.py b/bot/modules/reminders/reminder.py index 73870341..67956a1d 100644 --- a/bot/modules/reminders/reminder.py +++ b/bot/modules/reminders/reminder.py @@ -1,8 +1,9 @@ import asyncio import datetime +import logging import discord -from meta import client +from meta import client, sharding from utils.lib import strfdur from .data import reminders @@ -46,7 +47,10 @@ class Reminder: cls._live_reminders[reminderid].cancel() # Remove from data - reminders.delete_where(reminderid=reminderids) + if reminderids: + return reminders.delete_where(reminderid=reminderids) + else: + return [] @property def data(self): @@ -134,10 +138,16 @@ class Reminder: """ Execute the reminder. """ + if not self.data: + # Reminder deleted elsewhere + return + if self.data.userid in client.user_blacklist(): self.delete(self.reminderid) return + userid = self.data.userid + # Build the message embed embed = discord.Embed( title="You asked me to remind you!", @@ -155,8 +165,26 @@ class Reminder: ) ) + # Update the reminder data, and reschedule if required + if self.data.interval: + next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval) + rows = reminders.update_where( + {'remind_at': next_time}, + reminderid=self.reminderid + ) + self.schedule() + else: + rows = self.delete(self.reminderid) + if not rows: + # Reminder deleted elsewhere + return + # Send the message, if possible - user = self.user + if not (user := client.get_user(userid)): + try: + user = await client.fetch_user(userid) + except discord.HTTPException: + pass if user: try: await user.send(embed=embed) @@ -164,21 +192,38 @@ class Reminder: # Nothing we can really do here. Maybe tell the user about their reminder next time? pass - # Update the reminder data, and reschedule if required - if self.data.interval: - next_time = self.data.remind_at + datetime.timedelta(seconds=self.data.interval) - reminders.update_where({'remind_at': next_time}, reminderid=self.reminderid) - self.schedule() - else: - self.delete(self.reminderid) + +async def reminder_poll(client): + """ + One client/shard must continually poll for new or deleted reminders. + """ + # TODO: Clean this up with database signals or IPC + while True: + await asyncio.sleep(60) + + client.log( + "Running new reminder poll.", + context="REMINDERS", + level=logging.DEBUG + ) + + rids = {row.reminderid for row in reminders.fetch_rows_where()} + + to_delete = (rid for rid in Reminder._live_reminders if rid not in rids) + Reminder.delete(*to_delete) + + [Reminder(rid).schedule() for rid in rids if rid not in Reminder._live_reminders] @module.launch_task async def schedule_reminders(client): - rows = reminders.fetch_rows_where() - for row in rows: - Reminder(row.reminderid).schedule() - client.log( - "Scheduled {} reminders.".format(len(rows)), - context="LAUNCH_REMINDERS" - ) + if sharding.shard_number == 0: + rows = reminders.fetch_rows_where() + for row in rows: + Reminder(row.reminderid).schedule() + client.log( + "Scheduled {} reminders.".format(len(rows)), + context="LAUNCH_REMINDERS" + ) + if sharding.sharded: + asyncio.create_task(reminder_poll(client))