Day 5 part 2 solution.
This commit is contained in:
@@ -10,6 +10,7 @@ logger = logging.getLogger()
|
|||||||
logger.addHandler(logging.StreamHandler())
|
logger.addHandler(logging.StreamHandler())
|
||||||
|
|
||||||
|
|
||||||
|
# Represents a DAG with maps a -> {descendents of a}
|
||||||
RuleSet: TypeAlias = dict[str, set[str]]
|
RuleSet: TypeAlias = dict[str, set[str]]
|
||||||
|
|
||||||
|
|
||||||
@@ -37,6 +38,35 @@ def check_line(line: list[str], ruleset: RuleSet) -> bool:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def sort_line(line: list[str], ruleset: RuleSet) -> list[str]:
|
||||||
|
"""
|
||||||
|
Sort a given list of items by the given ruleset.
|
||||||
|
This actually returns the reverse sorted line.
|
||||||
|
"""
|
||||||
|
sorted_line = []
|
||||||
|
visited = set()
|
||||||
|
|
||||||
|
while node := next((n for n in line if n not in visited), None):
|
||||||
|
_depth_sort_visit(node, line, ruleset, sorted_line, visited)
|
||||||
|
|
||||||
|
return sorted_line
|
||||||
|
|
||||||
|
def _depth_sort_visit(
|
||||||
|
node: str, line: list[str], ruleset: RuleSet,
|
||||||
|
sorting: list[str], visited: set[str]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Note we assume the graph is truly a DAG and has no cycles,
|
||||||
|
or this will create and infinite loop.
|
||||||
|
"""
|
||||||
|
if node in visited:
|
||||||
|
return
|
||||||
|
for parent in ruleset[node].intersection(line):
|
||||||
|
_depth_sort_visit(parent, line, ruleset, sorting, visited)
|
||||||
|
|
||||||
|
visited.add(node)
|
||||||
|
sorting.append(node)
|
||||||
|
|
||||||
def get_middle(line: list[str]) -> str:
|
def get_middle(line: list[str]) -> str:
|
||||||
"""
|
"""
|
||||||
Get the middle of the line, if it exists.
|
Get the middle of the line, if it exists.
|
||||||
@@ -45,7 +75,7 @@ def get_middle(line: list[str]) -> str:
|
|||||||
raise ValueError("Even length line has no middle.")
|
raise ValueError("Even length line has no middle.")
|
||||||
return line[len(line) // 2]
|
return line[len(line) // 2]
|
||||||
|
|
||||||
def process_data(lines: Iterator[str]):
|
def sum_correct_middles(lines: Iterator[str]):
|
||||||
lines = map(op.methodcaller('strip'), lines)
|
lines = map(op.methodcaller('strip'), lines)
|
||||||
lines = it.dropwhile(op.not_, lines)
|
lines = it.dropwhile(op.not_, lines)
|
||||||
ruleset = parse_rules(it.takewhile(bool, lines))
|
ruleset = parse_rules(it.takewhile(bool, lines))
|
||||||
@@ -64,11 +94,33 @@ def process_data(lines: Iterator[str]):
|
|||||||
|
|
||||||
return total
|
return total
|
||||||
|
|
||||||
|
def sum_incorrect_middles(lines: Iterator[str]):
|
||||||
|
lines = map(op.methodcaller('strip'), lines)
|
||||||
|
lines = it.dropwhile(op.not_, lines)
|
||||||
|
ruleset = parse_rules(it.takewhile(bool, lines))
|
||||||
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
|
logger.debug(f"Ruleset: {ruleset}")
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
for line in lines:
|
||||||
|
if line:
|
||||||
|
parts = line.split(',')
|
||||||
|
if not check_line(parts, ruleset):
|
||||||
|
logger.debug(f"Invalid line: {line}")
|
||||||
|
corrected = sort_line(parts, ruleset)
|
||||||
|
logger.debug(f"Sorted line: {corrected}")
|
||||||
|
total += int(get_middle(corrected))
|
||||||
|
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
with open(sys.argv[1]) as f:
|
with open(sys.argv[1]) as f:
|
||||||
total = process_data(f)
|
total = sum_correct_middles(f)
|
||||||
print(f"Total middle-sum is {total}")
|
print(f"Total middle-sum of correct lines is {total}")
|
||||||
|
with open(sys.argv[1]) as f:
|
||||||
|
total = sum_incorrect_middles(f)
|
||||||
|
print(f"Total middle-sum of corrected incorrect lines is {total}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
12
day5/test.py
12
day5/test.py
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from solver import process_data
|
from solver import sum_correct_middles, sum_incorrect_middles
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
@@ -37,10 +37,14 @@ test_data = r"""
|
|||||||
|
|
||||||
|
|
||||||
def test_main():
|
def test_main():
|
||||||
print("Beginning basic test")
|
print("Beginning correct middle sum test")
|
||||||
result = process_data(iter(test_data.splitlines()))
|
result = sum_correct_middles(iter(test_data.splitlines()))
|
||||||
assert result == 143
|
assert result == 143
|
||||||
print("Basic test passed")
|
print("Correct middle sum test passed")
|
||||||
|
print("Beginning incorrect middle sum test")
|
||||||
|
result = sum_incorrect_middles(iter(test_data.splitlines()))
|
||||||
|
assert result == 123
|
||||||
|
print("Incorrect middle sum test passed")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_main()
|
test_main()
|
||||||
|
|||||||
Reference in New Issue
Block a user