374 lines
13 KiB
Python
374 lines
13 KiB
Python
from typing import Optional
|
|
from weakref import WeakValueDictionary
|
|
import re
|
|
|
|
from meta import LionBot
|
|
from meta.errors import UserInputError
|
|
from utils.lib import utc_now
|
|
|
|
from . import babel
|
|
from .data import TasklistData
|
|
|
|
|
|
# Module-local alias for the babel helper used to mark translatable strings.
# Below it is always called as `_p(<context key>, <default message>)` and the
# result is passed to `bot.translator.t` before formatting.
# NOTE(review): presumably a pgettext-style marker - confirm against `babel`.
_p = babel._p
|
|
|
|
|
|
class Tasklist:
    """
    Class representing a single user's tasklist.

    Instances are normally obtained through `fetch`, which shares one
    instance per user through a weak-value cache.

    Attributes
    ----------
    bot: LionBot
        Client which controls this tasklist.
    data: TasklistData
        Initialised tasklist data registry.
    userid: int
        The user who owns this tasklist.
    tasklist: dict[int, TasklistData.Task]
        A local cache map of tasks the user owns.
        May or may not contain deleted tasks.
    """
    # userid -> Tasklist. Weak values, so an entry disappears once nothing
    # else references the Tasklist.
    _cache_ = WeakValueDictionary()

    # Matches a single label (`1`, `2.1`) or a label range (`2-5`, `2.1-3`, `2.1-`).
    label_range_re = re.compile(
        r"^(?P<start>(\d+\.)*\d+)\.?((\s*(?P<range>-)\s*)(?P<end>(\d+\.)*\d*\.?))?$"
    )
    # Matches one line of a flattened tasklist: indent, optional `-`,
    # optional `[ ]`/`[x]` checkbox, then the task content.
    line_regex = re.compile(r"(?P<depth>\s*)-?\s*(\[\s*(?P<check>[^]]?)\s*\]\s*)?(?P<content>.*)")

    def __init__(self, bot: 'LionBot', data: 'TasklistData', userid: int):
        self.bot = bot
        self.data = data
        self.userid = userid

        # Local task cache keyed by taskid; filled by `refresh` and `create_task`.
        self.tasklist: dict[int, TasklistData.Task] = {}

    @classmethod
    async def fetch(cls, bot: 'LionBot', data: 'TasklistData', userid: int) -> 'Tasklist':
        """
        Fetch and initialise a Tasklist, using cache where possible.

        A cached instance is returned as-is (it is not refreshed).
        Otherwise a new instance is created, refreshed from data and cached.
        """
        # A single `.get` keeps a strong reference to the cached instance, so
        # it cannot be collected between the lookup and the return.
        tasklist = cls._cache_.get(userid)
        if tasklist is None:
            tasklist = cls(bot, data, userid)
            await tasklist.refresh()
            cls._cache_[userid] = tasklist
        return tasklist

    def _label(self, task, taskmap, labels, counters) -> tuple[int, ...]:
        """
        Compute the label of `task`, memoising the result in `labels`.

        A label is the parent's label extended by this task's position among
        the children numbered so far under that parent. Tasks without a valid
        parent get a one-element label.

        Parameters
        ----------
        task:
            The task to label. Reads `task.taskid` and `task.parentid`.
        taskmap:
            Map of taskid -> task for every task eligible for labelling.
        labels:
            Map of taskid -> label computed so far. Mutated in place.
        counters:
            Map of parentid (None for root tasks) -> number of children
            numbered so far. Mutated in place.
        """
        tid = task.taskid

        if tid in labels:
            label = labels[tid]
        else:
            pid = task.parentid
            if (parent := taskmap.get(pid, None)) is None:
                # Case where parent is not a valid task
                pid = None
            if pid is not None and pid in counters and pid not in labels:
                # Parent has started being labelled before
                # But has not finished!
                # This implies a cycle (i.e. tasklist is not a tree)
                # Prune the cycle
                pid = None

            # Claim this task's position before recursing into the parent.
            counters[pid] = i = counters.get(pid, 0) + 1
            if pid is not None:
                plabel = self._label(parent, taskmap, labels, counters)
            else:
                plabel = ()
            labels[tid] = label = (*plabel, i)
        return label

    @property
    def labelled(self) -> 'dict[tuple[int, ...], TasklistData.Task]':
        """
        A map of label tuples to tasks, ordered by label.

        Only non-deleted tasks from the local cache are included.
        This is the tasklist that is visible to the user.
        """
        live_tasks = sorted(
            (task for task in self.tasklist.values() if task.deleted_at is None),
            key=lambda task: task.taskid,
        )
        taskmap = {task.taskid: task for task in live_tasks}

        labels = {}
        counters = {}
        for task in live_tasks:
            self._label(task, taskmap, labels, counters)

        by_label = sorted(labels.items(), key=lambda item: item[1])
        return {label: taskmap[taskid] for taskid, label in by_label}

    def labelid(self, taskid) -> Optional[tuple[int, ...]]:
        """
        Relatively expensive method to get the label for a given task, if it exists.

        Returns None when the task is not in the local cache, or when it is
        cached but has no label (e.g. it has been deleted).
        """
        if taskid not in self.tasklist:
            return None

        # `labelled` skips deleted tasks, so a cached task may have no label.
        for label, task in self.labelled.items():
            if task.taskid == taskid:
                return label
        return None

    async def refresh(self):
        """
        Update the `tasklist` from data.

        Replaces the local cache with the user's tasks that have no `deleted_at`.
        """
        tasks = await self.data.Task.fetch_where(userid=self.userid, deleted_at=None)
        self.tasklist = {task.taskid: task for task in tasks}

    async def _owner_check(self, *taskids: int) -> bool:
        """
        Check whether all of the given tasks are owned by this tasklist user.

        Applies cache where possible: ids in the local `tasklist` are accepted,
        ids with a cached Task row are checked against that row, and only the
        remaining ids are fetched from data.

        Ids for which no row is returned are not reported as a failure.
        """
        unknown = [tid for tid in taskids if tid not in self.tasklist]
        if not unknown:
            return True

        row_cache = self.data.Task._cache_
        uncached = []
        for tid in unknown:
            try:
                row = row_cache[(tid, )]
            except KeyError:
                uncached.append(tid)
                continue
            # A cached row must still belong to this user.
            if row.userid != self.userid:
                return False

        if uncached:
            rows = await self.data.Task.fetch_where(taskid=uncached)
            if any(row.userid != self.userid for row in rows):
                return False

        return True

    async def fetch_tasks(self, *taskids: int) -> 'list[TasklistData.Task]':
        """
        Fetch the tasks from the tasklist with the given taskids.

        Raises a ValueError if the tasks are not owned by the tasklist user.
        """
        # Check the tasklist user owns all the tasks
        # Also ensures the Row objects are in cache
        if not await self._owner_check(*taskids):
            raise ValueError("The given tasks are not in this tasklist!")
        return [await self.data.Task.fetch(tid) for tid in taskids]

    async def create_task(self, content: str, **kwargs) -> 'TasklistData.Task':
        """
        Create a new task with the given content.

        Extra keyword arguments are passed through to `Task.create`.
        The new task is added to the local cache and returned.
        """
        task = await self.data.Task.create(userid=self.userid, content=content, **kwargs)
        self.tasklist[task.taskid] = task
        return task

    async def update_tasks(self, *taskids: int, cascade=False, **kwargs):
        """
        Update the given taskids with the provided new values.

        If `cascade` is True, also applies the updates to all children.
        `last_updated_at` defaults to the current time unless provided.

        Raises a ValueError if no taskids are given, or if the tasks are not
        owned by the tasklist user.
        Returns the result of the underlying update query.
        """
        if not taskids:
            raise ValueError("No tasks provided to update.")

        if cascade:
            taskids = self.children_cascade(*taskids)

        # Ensure the taskids exist and belong to this user
        await self.fetch_tasks(*taskids)

        # Update the tasks
        kwargs.setdefault('last_updated_at', utc_now())
        tasks = await self.data.Task.table.update_where(
            userid=self.userid,
            taskid=taskids,
        ).set(**kwargs)

        # Return the updated tasks
        return tasks

    async def update_tasklist(self, **kwargs):
        """
        Update every task in the tasklist, regardless of cache.

        Applies to all of the user's tasks that have no `deleted_at`.
        `last_updated_at` defaults to the current time unless provided.
        Returns the result of the underlying update query.
        """
        kwargs.setdefault('last_updated_at', utc_now())
        tasks = await self.data.Task.table.update_where(
            userid=self.userid,
            deleted_at=None
        ).set(**kwargs)

        return tasks

    def children_cascade(self, *taskids) -> list[int]:
        """
        Return the provided taskids with all their descendants.

        Only checks the current tasklist cache for descendants, and only
        non-deleted tasks are added. The order of the result is unspecified.
        """
        # Index the non-deleted cached tasks by parent once, then walk down.
        children = {}
        for task in self.tasklist.values():
            if task.deleted_at is None:
                children.setdefault(task.parentid, []).append(task.taskid)

        collected = set(taskids)
        pending = list(collected)
        while pending:
            for childid in children.get(pending.pop(), ()):
                # The membership check also protects against parent cycles.
                if childid not in collected:
                    collected.add(childid)
                    pending.append(childid)
        return list(collected)

    def parse_label(self, labelstr: str) -> Optional[int]:
        """
        Parse a provided label string into a taskid, if it can be found.

        Empty components are ignored, so `1.`, `1` and `.1` are equivalent.

        Returns None if no matching taskids are found.
        """
        parts = [part for part in labelstr.split('.') if part]
        # `isdecimal` (not `isdigit`) so that every accepted part is valid for
        # `int()`; characters such as superscript digits are rejected here.
        if not all(part.isdecimal() for part in parts):
            return None

        task = self.labelled.get(tuple(map(int, parts)))
        return task.taskid if task is not None else None

    def format_label(self, label: tuple[int, ...]) -> str:
        """
        Format the provided label tuple into the standard number format.

        Components are joined with `.`; a single-component label gets a
        trailing `.` (e.g. `(1,)` -> `1.`, `(1, 2)` -> `1.2`).
        """
        formatted = '.'.join(map(str, label))
        return formatted + '.' if len(label) == 1 else formatted

    def parse_labels(self, labelstr: str) -> Optional[list[int]]:
        """
        Parse a comma separated list of labels and label ranges into a list of taskids.

        E.g. `1, 2, 3`, `1, 2-5, 7`, `1, 2.1, 3`, `1, 2.1-3`, `1, 2.1-`

        Labels that parse but match no task are skipped. A closed range stops
        at the first label in it that has no task. The order of the returned
        taskids is unspecified.

        May raise `UserInputError`.
        """
        labelmap = {label: task.taskid for label, task in self.labelled.items()}

        stripped = (part.strip(' ,.') for part in labelstr.split(','))
        taskids = set()
        for split in filter(None, stripped):
            match = self.label_range_re.match(split)
            if not match:
                # Error
                t = self.bot.translator.t
                raise UserInputError(
                    t(_p(
                        'tasklist|parse:multi-range|error:parse',
                        "Could not parse `{range}` as a task number or range."
                    )).format(range=split)
                )

            end = match['end']
            start_label = tuple(map(int, match['start'].split('.')))
            head = start_label[:-1]
            start_tail = start_label[-1]

            if end:
                end_label = tuple(map(int, end.split('.')))

                # The end may be given in full (`2.1-2.3`) or as a tail only (`2.1-3`).
                if len(end_label) > 1 and head != end_label[:-1]:
                    # Error: Parents don't match in range ...
                    t = self.bot.translator.t
                    raise UserInputError(
                        t(_p(
                            'tasklist|parse:multi-range|error:parents_match',
                            "Parents don't match in range `{range}`"
                        )).format(range=split)
                    )

                for tail in range(max(start_tail, 1), end_label[-1] + 1):
                    label = (*head, tail)
                    if label not in labelmap:
                        break
                    taskids.add(labelmap[label])
            elif match['range']:
                # No end but still ranged
                taskids.update(
                    taskid
                    for label, taskid in labelmap.items()
                    if label[:-1] == head and label[-1] >= start_tail
                )
            elif start_label in labelmap:
                taskids.add(labelmap[start_label])
        return list(taskids)

    def flatten(self):
        """
        Flatten the tasklist to a map of readable strings parseable by `parse_tasklist`.

        Returns a dict of taskid -> line, in label order. Each line is indented
        by label depth and starts with a `- [ ]` or `- [x]` checkbox.
        """
        lines = {}
        for label, task in self.labelled.items():
            prefix = ' ' * (len(label) - 1)
            box = '- [ ]' if task.completed_at is None else '- [x]'
            lines[task.taskid] = f"{prefix}{box} {task.content}"
        return lines

    def parse_tasklist(self, task_lines):
        """
        Parse flattened task lines into a list of task descriptions.

        Each returned entry is a tuple `(parent, truedepth, ticked, content)`,
        where `parent` is the index of the parent entry in the returned list
        (None for root tasks) and `truedepth` is the nesting level.
        Lines without content are skipped.

        Raises `UserInputError` if a line is malformed or its content is
        longer than 200 characters.
        """
        t = self.bot.translator.t
        taskinfo = []  # (parent, truedepth, ticked, content)
        depthtree = []  # (depth, index)

        for line in task_lines:
            match = self.line_regex.match(line)
            if not match:
                raise UserInputError(
                    t(_p(
                        'modal:tasklist_bulk_editor|error:parse_task',
                        "Malformed taskline!\n`{input}`"
                    )).format(input=line)
                )
            depth = len(match['depth'])
            check = bool(match['check'])
            content = match['content']
            if not content:
                continue
            if len(content) > 200:
                raise UserInputError(
                    t(_p(
                        'modal:tasklist_bulk_editor|error:task_too_long',
                        "Please keep your tasks under 200 characters!"
                    ))
                )

            # Unwind the ancestor stack: drop every deeper entry, then at most
            # one entry at the same depth (the previous sibling).
            while depthtree and depthtree[-1][0] > depth:
                depthtree.pop()
            if depthtree and depthtree[-1][0] == depth:
                depthtree.pop()

            parent = depthtree[-1][1] if depthtree else None
            depthtree.append((depth, len(taskinfo)))
            taskinfo.append((parent, len(depthtree) - 1, check, content))
        return taskinfo

    async def write_taskinfo(self, taskinfo):
        """
        Create tasks from `taskinfo` (matching the output of `parse_tasklist`).

        Tasks are inserted one depth level at a time, so that the ids of the
        created parents are known before their children are inserted.
        Ticked tasks get the current time as `completed_at`.
        """
        now = utc_now()
        created = {}  # taskinfo index -> created taskid
        target_depth = 0
        while True:
            to_insert = {
                i: (
                    self.userid,
                    content,
                    created[parent] if parent is not None else None,
                    now if ticked else None
                )
                for i, (parent, truedepth, ticked, content) in enumerate(taskinfo)
                if truedepth == target_depth
            }
            if not to_insert:
                # Reached maximum depth
                break

            # Batch insert
            tasks = await self.data.Task.table.insert_many(
                ('userid', 'content', 'parentid', 'completed_at'),
                *to_insert.values()
            )
            for i, task in zip(to_insert.keys(), tasks):
                created[i] = task['taskid']
            target_depth += 1
|