from collections.abc import Iterable import operator as op import itertools as it from collections import defaultdict from typing import Iterator, TypeAlias import logging import sys logger = logging.getLogger() logger.addHandler(logging.StreamHandler()) # Represents a DAG with maps a -> {descendents of a} RuleSet: TypeAlias = dict[str, set[str]] def parse_rules(rule_lines: Iterable[str]) -> RuleSet: """ Parse a sequence of poset rules into a rulset mapping id -> list of descendents. """ descendents = defaultdict(set) for line in rule_lines: if '|' not in line: raise ValueError(f"Provided line '{line}' is not a rule.") a, b = line.split('|') descendents[a].add(b) return descendents def check_line(line: list[str], ruleset: RuleSet) -> bool: """ Check whether the provided line satisfies the given rules. """ for i, char in enumerate(line): rules = ruleset[char] if any(prevchar in rules for prevchar in line[:i]): return False return True def sort_line(line: list[str], ruleset: RuleSet) -> list[str]: """ Sort a given list of items by the given ruleset. This actually returns the reverse sorted line. """ sorted_line = [] visited = set() while node := next((n for n in line if n not in visited), None): _depth_sort_visit(node, line, ruleset, sorted_line, visited) return sorted_line def _depth_sort_visit( node: str, line: list[str], ruleset: RuleSet, sorting: list[str], visited: set[str] ): """ Note we assume the graph is truly a DAG and has no cycles, or this will create and infinite loop. """ if node in visited: return for parent in ruleset[node].intersection(line): _depth_sort_visit(parent, line, ruleset, sorting, visited) visited.add(node) sorting.append(node) def get_middle(line: list[str]) -> str: """ Get the middle of the line, if it exists. """ if not len(line) % 2: raise ValueError("Even length line has no middle.") return line[len(line) // 2] def sum_correct_middles(lines: Iterator[str]): lines = map(op.methodcaller('strip'), lines) lines = it.dropwhile(op.not_, lines) ruleset = parse_rules(it.takewhile(bool, lines)) if logger.isEnabledFor(logging.DEBUG): logger.debug(f"Ruleset: {ruleset}") total = 0 for line in lines: if line: parts = line.split(',') if check_line(parts, ruleset): total += int(get_middle(parts)) logger.debug(f"Valid line: {line}") else: logger.debug(f"Invalid line: {line}") return total def sum_incorrect_middles(lines: Iterator[str]): lines = map(op.methodcaller('strip'), lines) lines = it.dropwhile(op.not_, lines) ruleset = parse_rules(it.takewhile(bool, lines)) if logger.isEnabledFor(logging.DEBUG): logger.debug(f"Ruleset: {ruleset}") total = 0 for line in lines: if line: parts = line.split(',') if not check_line(parts, ruleset): logger.debug(f"Invalid line: {line}") corrected = sort_line(parts, ruleset) logger.debug(f"Sorted line: {corrected}") total += int(get_middle(corrected)) return total def main(): with open(sys.argv[1]) as f: total = sum_correct_middles(f) print(f"Total middle-sum of correct lines is {total}") with open(sys.argv[1]) as f: total = sum_incorrect_middles(f) print(f"Total middle-sum of corrected incorrect lines is {total}") if __name__ == '__main__': main()