diff --git a/day5/solver.py b/day5/solver.py index 894c4f8..9341543 100644 --- a/day5/solver.py +++ b/day5/solver.py @@ -10,6 +10,7 @@ logger = logging.getLogger() logger.addHandler(logging.StreamHandler()) +# Represents a DAG with maps a -> {descendents of a} RuleSet: TypeAlias = dict[str, set[str]] @@ -37,6 +38,35 @@ def check_line(line: list[str], ruleset: RuleSet) -> bool: return True +def sort_line(line: list[str], ruleset: RuleSet) -> list[str]: + """ + Sort a given list of items by the given ruleset. + This actually returns the reverse sorted line. + """ + sorted_line = [] + visited = set() + + while node := next((n for n in line if n not in visited), None): + _depth_sort_visit(node, line, ruleset, sorted_line, visited) + + return sorted_line + +def _depth_sort_visit( + node: str, line: list[str], ruleset: RuleSet, + sorting: list[str], visited: set[str] +): + """ + Note we assume the graph is truly a DAG and has no cycles, + or this will create and infinite loop. + """ + if node in visited: + return + for parent in ruleset[node].intersection(line): + _depth_sort_visit(parent, line, ruleset, sorting, visited) + + visited.add(node) + sorting.append(node) + def get_middle(line: list[str]) -> str: """ Get the middle of the line, if it exists. @@ -45,7 +75,7 @@ def get_middle(line: list[str]) -> str: raise ValueError("Even length line has no middle.") return line[len(line) // 2] -def process_data(lines: Iterator[str]): +def sum_correct_middles(lines: Iterator[str]): lines = map(op.methodcaller('strip'), lines) lines = it.dropwhile(op.not_, lines) ruleset = parse_rules(it.takewhile(bool, lines)) @@ -64,11 +94,33 @@ def process_data(lines: Iterator[str]): return total +def sum_incorrect_middles(lines: Iterator[str]): + lines = map(op.methodcaller('strip'), lines) + lines = it.dropwhile(op.not_, lines) + ruleset = parse_rules(it.takewhile(bool, lines)) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Ruleset: {ruleset}") + + total = 0 + for line in lines: + if line: + parts = line.split(',') + if not check_line(parts, ruleset): + logger.debug(f"Invalid line: {line}") + corrected = sort_line(parts, ruleset) + logger.debug(f"Sorted line: {corrected}") + total += int(get_middle(corrected)) + + return total + def main(): with open(sys.argv[1]) as f: - total = process_data(f) - print(f"Total middle-sum is {total}") + total = sum_correct_middles(f) + print(f"Total middle-sum of correct lines is {total}") + with open(sys.argv[1]) as f: + total = sum_incorrect_middles(f) + print(f"Total middle-sum of corrected incorrect lines is {total}") if __name__ == '__main__': diff --git a/day5/test.py b/day5/test.py index 34607ed..a308dee 100644 --- a/day5/test.py +++ b/day5/test.py @@ -1,5 +1,5 @@ import logging -from solver import process_data +from solver import sum_correct_middles, sum_incorrect_middles logger = logging.getLogger() logger.setLevel(logging.DEBUG) @@ -37,10 +37,14 @@ test_data = r""" def test_main(): - print("Beginning basic test") - result = process_data(iter(test_data.splitlines())) + print("Beginning correct middle sum test") + result = sum_correct_middles(iter(test_data.splitlines())) assert result == 143 - print("Basic test passed") + print("Correct middle sum test passed") + print("Beginning incorrect middle sum test") + result = sum_incorrect_middles(iter(test_data.splitlines())) + assert result == 123 + print("Incorrect middle sum test passed") if __name__ == '__main__': test_main()