179 lines
6.8 KiB
Python
179 lines
6.8 KiB
Python
import asyncio
|
|
import bisect
|
|
import logging
|
|
from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any
|
|
|
|
from .lib import utc_now
|
|
from .ratelimits import Bucket
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
Taskid = TypeVar('Taskid')
|
|
|
|
|
|
class TaskMonitor(Generic[Taskid]):
|
|
"""
|
|
Base class for a task monitor.
|
|
|
|
Stores tasks as a time-sorted list of taskids.
|
|
Subclasses may override `run_task` to implement an executor.
|
|
|
|
Adding or removing a single task has O(n) performance.
|
|
To bulk update tasks, instead use `schedule_tasks`.
|
|
|
|
Each taskid must be unique and hashable.
|
|
"""
|
|
|
|
def __init__(self, executor=None, bucket: Optional[Bucket] = None):
|
|
# Ratelimit bucket to enforce maximum execution rate
|
|
self._bucket = bucket
|
|
|
|
self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor
|
|
|
|
self._wakeup: asyncio.Event = asyncio.Event()
|
|
self._monitor_task: Optional[self.Task] = None
|
|
|
|
# Task data
|
|
self._tasklist: list[Taskid] = []
|
|
self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp
|
|
|
|
# Running map ensures we keep a reference to the running task
|
|
# And allows simpler external cancellation if required
|
|
self._running: dict[Taskid, asyncio.Future] = {}
|
|
|
|
def set_tasks(self, *tasks: tuple[Taskid, int]) -> None:
|
|
"""
|
|
Similar to `schedule_tasks`, but wipe and reset the tasklist.
|
|
"""
|
|
self._taskmap = {tid: time for tid, time in tasks}
|
|
self._tasklist = sorted(self._taskmap.keys(), key=lambda tid: -1 * tid * self._taskmap[tid])
|
|
self._wakeup.set()
|
|
|
|
def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None:
|
|
"""
|
|
Schedule the given tasks.
|
|
|
|
Rather than repeatedly inserting tasks,
|
|
where the O(log n) insort is dominated by the O(n) list insertion,
|
|
we build an entirely new list, and always wake up the loop.
|
|
"""
|
|
self._taskmap |= {tid: time for tid, time in tasks}
|
|
self._tasklist = sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid])
|
|
self._wakeup.set()
|
|
|
|
def schedule_task(self, taskid: Taskid, timestamp: int) -> None:
|
|
"""
|
|
Insert the provided task into the tasklist.
|
|
If the new task has a lower timestamp than the next task, wakes up the monitor loop.
|
|
"""
|
|
if self._tasklist:
|
|
nextid = self._tasklist[-1]
|
|
wake = self._taskmap[nextid] >= timestamp
|
|
wake = wake or taskid == nextid
|
|
else:
|
|
wake = False
|
|
if taskid in self._taskmap:
|
|
self._tasklist.remove(taskid)
|
|
self._taskmap[taskid] = timestamp
|
|
bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t])
|
|
if wake:
|
|
self._wakeup.set()
|
|
|
|
def cancel_tasks(self, *taskids: Taskid) -> None:
|
|
"""
|
|
Remove all tasks with the given taskids from the tasklist.
|
|
If the next task has this taskid, wake up the monitor loop.
|
|
"""
|
|
taskids = set(taskids)
|
|
wake = (self._tasklist and self._tasklist[-1] in taskids)
|
|
self._tasklist = [tid for tid in self._tasklist if tid not in taskids]
|
|
for tid in taskids:
|
|
self._taskmap.pop(tid, None)
|
|
if wake:
|
|
self._wakeup.set()
|
|
|
|
def start(self):
|
|
if self._monitor_task and not self._monitor_task.done():
|
|
self._monitor_task.cancel()
|
|
# Start the monitor
|
|
self._monitor_task = asyncio.create_task(self.monitor())
|
|
return self._monitor_task
|
|
|
|
async def monitor(self):
|
|
"""
|
|
Start the monitor.
|
|
Executes the tasks in `self.tasks` at the specified time.
|
|
|
|
This will shield task execution from cancellation
|
|
to avoid partial states.
|
|
"""
|
|
try:
|
|
while True:
|
|
self._wakeup.clear()
|
|
if not self._tasklist:
|
|
# No tasks left, just sleep until wakeup
|
|
await self._wakeup.wait()
|
|
else:
|
|
# Get the next task, sleep until wakeup or it is ready to run
|
|
nextid = self._tasklist[-1]
|
|
nexttime = self._taskmap[nextid]
|
|
sleep_for = nexttime - utc_now().timestamp()
|
|
try:
|
|
await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for)
|
|
except asyncio.TimeoutError:
|
|
# Ready to run the task
|
|
self._tasklist.pop()
|
|
self._taskmap.pop(nextid, None)
|
|
self._running[nextid] = asyncio.ensure_future(self._run(nextid))
|
|
else:
|
|
# Wakeup task fired, loop again
|
|
continue
|
|
except asyncio.CancelledError:
|
|
# Log closure and wait for remaining tasks
|
|
# A second cancellation will also cancel the tasks
|
|
logger.debug(
|
|
f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. "
|
|
f"Waiting for {len(self._running)} running tasks to complete."
|
|
)
|
|
await asyncio.gather(*self._running.values(), return_exceptions=True)
|
|
|
|
async def _run(self, taskid: Taskid) -> None:
|
|
# Execute the task, respecting the ratelimit bucket
|
|
if self._bucket is not None:
|
|
# IMPLEMENTATION NOTE:
|
|
# Bucket.wait() should guarantee not more than n tasks/second are run
|
|
# and that a request directly afterwards will _not_ raise BucketFull
|
|
# Make sure that only one waiter is actually waiting on its sleep task
|
|
# The other waiters should be waiting on a lock around the sleep task
|
|
# Waiters are executed in wait-order, so if we only let a single waiter in
|
|
# we shouldn't get collisions.
|
|
# Furthermore, make sure we do _not_ pass back to the event loop after waiting
|
|
# Or we will lose thread-safety for BucketFull
|
|
await self._bucket.wait()
|
|
fut = asyncio.create_task(self.run_task(taskid))
|
|
try:
|
|
await asyncio.shield(fut)
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except Exception:
|
|
# Protect the monitor loop from any other exceptions
|
|
logger.exception(
|
|
f"Ignoring exception in task monitor {self.__class__.__name__} while "
|
|
f"executing <taskid: {taskid}>"
|
|
)
|
|
finally:
|
|
self._running.pop(taskid)
|
|
|
|
async def run_task(self, taskid: Taskid):
|
|
"""
|
|
Execute the task with the given taskid.
|
|
|
|
Default implementation executes `self.executor` if it exists,
|
|
otherwise raises NotImplementedError.
|
|
"""
|
|
if self.executor is not None:
|
|
await self.executor(taskid)
|
|
else:
|
|
raise NotImplementedError
|