Day 5 part 2 solution.

This commit is contained in:
2024-12-05 22:12:27 +10:00
parent 029a76b088
commit 921f03ec9d
2 changed files with 63 additions and 7 deletions

View File

@@ -10,6 +10,7 @@ logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
# Represents a DAG with maps a -> {descendents of a}
RuleSet: TypeAlias = dict[str, set[str]]
@@ -37,6 +38,35 @@ def check_line(line: list[str], ruleset: RuleSet) -> bool:
return True
def sort_line(line: list[str], ruleset: RuleSet) -> list[str]:
"""
Sort a given list of items by the given ruleset.
This actually returns the reverse sorted line.
"""
sorted_line = []
visited = set()
while node := next((n for n in line if n not in visited), None):
_depth_sort_visit(node, line, ruleset, sorted_line, visited)
return sorted_line
def _depth_sort_visit(
node: str, line: list[str], ruleset: RuleSet,
sorting: list[str], visited: set[str]
):
"""
Note we assume the graph is truly a DAG and has no cycles,
or this will create and infinite loop.
"""
if node in visited:
return
for parent in ruleset[node].intersection(line):
_depth_sort_visit(parent, line, ruleset, sorting, visited)
visited.add(node)
sorting.append(node)
def get_middle(line: list[str]) -> str:
"""
Get the middle of the line, if it exists.
@@ -45,7 +75,7 @@ def get_middle(line: list[str]) -> str:
raise ValueError("Even length line has no middle.")
return line[len(line) // 2]
def process_data(lines: Iterator[str]):
def sum_correct_middles(lines: Iterator[str]):
lines = map(op.methodcaller('strip'), lines)
lines = it.dropwhile(op.not_, lines)
ruleset = parse_rules(it.takewhile(bool, lines))
@@ -64,11 +94,33 @@ def process_data(lines: Iterator[str]):
return total
def sum_incorrect_middles(lines: Iterator[str]):
lines = map(op.methodcaller('strip'), lines)
lines = it.dropwhile(op.not_, lines)
ruleset = parse_rules(it.takewhile(bool, lines))
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Ruleset: {ruleset}")
total = 0
for line in lines:
if line:
parts = line.split(',')
if not check_line(parts, ruleset):
logger.debug(f"Invalid line: {line}")
corrected = sort_line(parts, ruleset)
logger.debug(f"Sorted line: {corrected}")
total += int(get_middle(corrected))
return total
def main():
with open(sys.argv[1]) as f:
total = process_data(f)
print(f"Total middle-sum is {total}")
total = sum_correct_middles(f)
print(f"Total middle-sum of correct lines is {total}")
with open(sys.argv[1]) as f:
total = sum_incorrect_middles(f)
print(f"Total middle-sum of corrected incorrect lines is {total}")
if __name__ == '__main__':

View File

@@ -1,5 +1,5 @@
import logging
from solver import process_data
from solver import sum_correct_middles, sum_incorrect_middles
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
@@ -37,10 +37,14 @@ test_data = r"""
def test_main():
print("Beginning basic test")
result = process_data(iter(test_data.splitlines()))
print("Beginning correct middle sum test")
result = sum_correct_middles(iter(test_data.splitlines()))
assert result == 143
print("Basic test passed")
print("Correct middle sum test passed")
print("Beginning incorrect middle sum test")
result = sum_incorrect_middles(iter(test_data.splitlines()))
assert result == 123
print("Incorrect middle sum test passed")
if __name__ == '__main__':
test_main()