128 lines
3.6 KiB
Python
128 lines
3.6 KiB
Python
from collections.abc import Iterable
|
|
import operator as op
|
|
import itertools as it
|
|
from collections import defaultdict
|
|
from typing import Iterator, TypeAlias
|
|
import logging
|
|
import sys
|
|
|
|
logger = logging.getLogger()
|
|
logger.addHandler(logging.StreamHandler())
|
|
|
|
|
|
# Represents a DAG with maps a -> {descendents of a}
|
|
RuleSet: TypeAlias = dict[str, set[str]]
|
|
|
|
|
|
def parse_rules(rule_lines: Iterable[str]) -> RuleSet:
    """
    Parse a sequence of ordering rules into a ruleset.

    Each rule has the form ``"a|b"``, meaning item ``a`` must appear before
    item ``b``. The result maps every ``a`` to the set of all ``b`` that must
    follow it. It is a ``defaultdict(set)``, so indexing it with an item that
    has no rules yields (and stores) an empty set.

    Raises:
        ValueError: if a line does not consist of exactly one ``|`` with a
            non-empty item on each side.
    """
    descendants: defaultdict[str, set[str]] = defaultdict(set)

    for line in rule_lines:
        before, sep, after = line.partition('|')
        # partition() keeps any further '|' inside `after`, so reject that,
        # a missing separator, and empty items on either side.
        if not sep or not before or not after or '|' in after:
            raise ValueError(f"Provided line '{line}' is not a rule.")
        descendants[before].add(after)

    return descendants
|
|
|
|
def check_line(line: list[str], ruleset: RuleSet) -> bool:
    """
    Check whether the provided line satisfies the given rules.

    A line is valid when no item appears after an item that ``ruleset`` says
    must come later. Items with no entry in ``ruleset`` are unconstrained.
    The ruleset is only read, never modified, so a plain dict works as well
    as the defaultdict returned by parse_rules.
    """
    seen: set[str] = set()
    for item in line:
        # Everything in ruleset[item] must come after item; having already
        # seen one of them means the order is violated. .get() avoids
        # inserting empty entries into a defaultdict (or KeyError on a dict).
        if not ruleset.get(item, frozenset()).isdisjoint(seen):
            return False
        seen.add(item)

    return True
|
|
|
|
def sort_line(line: list[str], ruleset: RuleSet) -> list[str]:
    """
    Sort a given list of items by the given ruleset.

    This actually returns the reverse sorted line: each item is placed after
    every item the rules say must follow it. Reverse the result to get the
    rule-respecting order (for an odd-length line the middle is the same
    either way). Only rules between items present in ``line`` are used, and
    repeated items appear once in the result. ``ruleset`` is not modified.

    Raises:
        ValueError: if the rules restricted to ``line`` contain a cycle.
    """
    sorted_line: list[str] = []
    visited: set[str] = set()

    # Visit every item in line order; already-finished items are skipped by
    # the helper. (Iterating directly also handles falsy items such as "".)
    for node in line:
        _depth_sort_visit(node, line, ruleset, sorted_line, visited)

    return sorted_line


def _depth_sort_visit(
    node: str, line: list[str], ruleset: RuleSet,
    sorting: list[str], visited: set[str],
    active: set[str] | None = None,
) -> None:
    """
    Depth-first post-order visit used by sort_line.

    Recursively appends to ``sorting`` every item of ``line`` that must come
    after ``node``, then appends ``node`` itself and records it in
    ``visited``. ``active`` holds the items on the current recursion path so
    a cycle is reported instead of recursing until RecursionError. Recursion
    depth grows with the longest rule chain among the line's items.

    Raises:
        ValueError: if a cycle through ``node`` is found among ``line``'s items.
    """
    if node in visited:
        return
    if active is None:
        active = set()
    if node in active:
        raise ValueError(f"Rules contain a cycle through {node!r}.")

    active.add(node)
    successors = ruleset.get(node, frozenset())
    # Walk successors in line order (not set order) so the output is
    # deterministic across runs, independent of string hash randomisation.
    for child in line:
        if child in successors:
            _depth_sort_visit(child, line, ruleset, sorting, visited, active)
    active.discard(node)

    visited.add(node)
    sorting.append(node)
|
|
|
|
def get_middle(line: list[str]) -> str:
    """
    Return the centre element of an odd-length line.

    Raises:
        ValueError: if the line has an even number of items (an empty line
            included), because there is then no single middle element.
    """
    length = len(line)
    if length % 2 == 0:
        raise ValueError("Even length line has no middle.")
    middle_index = (length - 1) // 2
    return line[middle_index]
|
|
|
|
def _read_updates(lines: Iterable[str]) -> tuple[RuleSet, Iterator[str]]:
    """
    Split raw puzzle input into its ruleset and its update lines.

    Lines are stripped and leading blank lines skipped. The rules run up to
    the next blank line, which takewhile consumes. Every remaining non-blank
    line is an update, yielded lazily as a stripped string.
    """
    stripped = it.dropwhile(op.not_, map(op.methodcaller('strip'), lines))
    # parse_rules drains the takewhile, leaving `stripped` on the updates.
    ruleset = parse_rules(it.takewhile(bool, stripped))
    logger.debug("Ruleset: %s", ruleset)
    return ruleset, filter(None, stripped)


def sum_correct_middles(lines: Iterable[str]) -> int:
    """
    Sum the middle items of the update lines that already satisfy the rules.

    ``lines`` is the raw input: ``a|b`` rule lines, a blank line, then
    comma-separated update lines.
    """
    ruleset, updates = _read_updates(lines)

    total = 0
    for line in updates:
        parts = line.split(',')
        if check_line(parts, ruleset):
            total += int(get_middle(parts))
            logger.debug("Valid line: %s", line)
        else:
            logger.debug("Invalid line: %s", line)

    return total


def sum_incorrect_middles(lines: Iterable[str]) -> int:
    """
    Sum the middle items of the rule-violating update lines after sorting.

    ``lines`` has the same format as for sum_correct_middles. sort_line
    returns the line reversed, but the middle of an odd-length line is the
    same either way.
    """
    ruleset, updates = _read_updates(lines)

    total = 0
    for line in updates:
        parts = line.split(',')
        if not check_line(parts, ruleset):
            logger.debug("Invalid line: %s", line)
            corrected = sort_line(parts, ruleset)
            logger.debug("Sorted line: %s", corrected)
            total += int(get_middle(corrected))

    return total
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """
    Print both middle-sums for the input file named on the command line.

    ``argv`` is the argument list without the program name and defaults to
    ``sys.argv[1:]``. The file is read once and both sums are computed from
    the same lines.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        raise SystemExit(f"usage: {sys.argv[0]} INPUT_FILE")

    with open(args[0], encoding="utf-8") as f:
        lines = f.readlines()

    total = sum_correct_middles(iter(lines))
    print(f"Total middle-sum of correct lines is {total}")

    total = sum_incorrect_middles(iter(lines))
    print(f"Total middle-sum of corrected incorrect lines is {total}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
|