Files
aoc2024/day5/solver.py
2024-12-05 22:12:27 +10:00

128 lines
3.6 KiB
Python

from collections.abc import Iterable
import operator as op
import itertools as it
from collections import defaultdict
from typing import Iterator, TypeAlias
import logging
import sys
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
# Represents a DAG with maps a -> {descendents of a}
RuleSet: TypeAlias = dict[str, set[str]]
def parse_rules(rule_lines: Iterable[str]) -> RuleSet:
    """
    Parse a sequence of ordering rules into a ruleset.

    Each rule line has the form ``a|b`` (``a`` must come before ``b``).
    Surrounding whitespace on either side is ignored. The result maps each
    ``a`` to the set of its direct successors ``b``. It is a ``defaultdict``,
    so looking up an item with no rules yields an empty set.

    Raises:
        ValueError: if a line does not contain exactly one ``|`` separating
            two non-empty ids.
    """
    successors: defaultdict[str, set[str]] = defaultdict(set)
    for line in rule_lines:
        parts = line.split('|')
        # Anything other than exactly two non-empty ids is malformed; report
        # it with a clear message instead of a tuple-unpacking error.
        if len(parts) != 2:
            raise ValueError(f"Provided line '{line}' is not a rule.")
        a, b = (part.strip() for part in parts)
        if not a or not b:
            raise ValueError(f"Provided line '{line}' is not a rule.")
        successors[a].add(b)
    return successors
def check_line(line: list[str], ruleset: RuleSet) -> bool:
    """
    Check whether the provided line satisfies the given rules.

    A line is invalid when an item appears after an item that a rule
    requires to follow it, i.e. some earlier item is in ``ruleset[item]``.
    Items with no entry in ``ruleset`` are unconstrained. The ruleset is
    only read, never modified, and may be a plain ``dict``.
    """
    seen: set[str] = set()
    for item in line:
        # An already-placed item that must come *after* this one breaks
        # the ordering.
        if not seen.isdisjoint(ruleset.get(item, ())):
            return False
        seen.add(item)
    return True
def sort_line(line: list[str], ruleset: RuleSet) -> list[str]:
    """
    Sort a given list of items by the given ruleset.

    The result is in *reverse* rule order: for a rule ``a|b`` (``a`` must
    precede ``b``), ``b`` appears before ``a`` in the returned list. Reverse
    the result to get the forward order. Each distinct item appears once.
    Items with no entry in ``ruleset`` are unconstrained, and the ruleset is
    never modified.

    Raises:
        ValueError: if the rules restricted to the items of ``line`` contain
            a cycle.
    """
    sorted_line: list[str] = []
    visited: set[str] = set()
    # Visiting an already-visited node is a no-op, so one pass over the line
    # reaches every item exactly once.
    for node in line:
        _depth_sort_visit(node, line, ruleset, sorted_line, visited)
    return sorted_line


def _depth_sort_visit(
    node: str, line: list[str], ruleset: RuleSet,
    sorting: list[str], visited: set[str],
    _active: set[str] | None = None,
):
    """
    Depth-first, post-order visit of ``node``.

    First visit every item of ``line`` that the rules say must come after
    ``node``. Then mark ``node`` in ``visited`` and append it to ``sorting``.
    ``_active`` holds the nodes on the current recursion path and is used to
    detect cycles; top-level callers leave it as ``None``.

    Raises:
        ValueError: if ``node`` is reached again while still being visited,
            i.e. the rules contain a cycle.
    """
    if node in visited:
        return
    if _active is None:
        _active = set()
    if node in _active:
        raise ValueError(f"Ruleset contains a cycle through '{node}'.")
    _active.add(node)
    successors = ruleset.get(node, set())
    # Walk children in line order rather than set order so the result does
    # not depend on string hash randomisation.
    for child in line:
        if child in successors:
            _depth_sort_visit(child, line, ruleset, sorting, visited, _active)
    _active.discard(node)
    visited.add(node)
    sorting.append(node)
def get_middle(line: list[str]) -> str:
    """
    Return the central element of an odd-length line.

    Raises ValueError for even-length (including empty) lines, which have
    no single middle element.
    """
    half, remainder = divmod(len(line), 2)
    if remainder == 0:
        raise ValueError("Even length line has no middle.")
    return line[half]
def _parse_input(lines: Iterable[str]) -> tuple[RuleSet, Iterator[str]]:
    """
    Split puzzle input into its ruleset and a lazy stream of update lines.

    Lines are stripped and leading blank lines are skipped. The rule section
    runs up to the first blank line. Every later non-blank line is returned,
    unsplit, by the second element of the result.
    """
    stripped = it.dropwhile(op.not_, map(op.methodcaller('strip'), lines))
    # takewhile also consumes the blank separator line, so the shared
    # iterator resumes at the first update line afterwards.
    ruleset = parse_rules(it.takewhile(bool, stripped))
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Ruleset: %s", ruleset)
    return ruleset, filter(None, stripped)


def sum_correct_middles(lines: Iterable[str]) -> int:
    """
    Sum the middle items of the updates that already satisfy the rules.

    ``lines`` is the raw puzzle input: ``a|b`` rule lines, a blank line,
    then comma-separated updates. Middle items are converted with ``int``.
    """
    ruleset, updates = _parse_input(lines)
    total = 0
    for line in updates:
        parts = line.split(',')
        if check_line(parts, ruleset):
            total += int(get_middle(parts))
            logger.debug("Valid line: %s", line)
        else:
            logger.debug("Invalid line: %s", line)
    return total


def sum_incorrect_middles(lines: Iterable[str]) -> int:
    """
    Sum the middle items of rule-violating updates after re-sorting them.

    Input format is the same as for ``sum_correct_middles``. Updates that
    already satisfy the rules are ignored. The others are re-ordered with
    ``sort_line``, whose reversed output has the same middle element.
    """
    ruleset, updates = _parse_input(lines)
    total = 0
    for line in updates:
        parts = line.split(',')
        if not check_line(parts, ruleset):
            logger.debug("Invalid line: %s", line)
            corrected = sort_line(parts, ruleset)
            logger.debug("Sorted line: %s", corrected)
            total += int(get_middle(corrected))
    return total
def main():
    """
    Solve both parts for the input file named on the command line and print
    the two totals.

    Exits with a usage message if no input file is given.
    """
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} INPUT_FILE")
    # Read the file once; both solvers accept any iterable of lines.
    with open(sys.argv[1], encoding="utf-8") as f:
        lines = f.readlines()
    total = sum_correct_middles(lines)
    print(f"Total middle-sum of correct lines is {total}")
    total = sum_incorrect_middles(lines)
    print(f"Total middle-sum of corrected incorrect lines is {total}")


if __name__ == '__main__':
    main()