from collections.abc import Iterable import operator as op import itertools as it from collections import defaultdict from typing import Iterator, TypeAlias import logging import sys logger = logging.getLogger() logger.addHandler(logging.StreamHandler()) RuleSet: TypeAlias = dict[str, set[str]] def parse_rules(rule_lines: Iterable[str]) -> RuleSet: """ Parse a sequence of poset rules into a rulset mapping id -> list of descendents. """ descendents = defaultdict(set) for line in rule_lines: if '|' not in line: raise ValueError(f"Provided line '{line}' is not a rule.") a, b = line.split('|') descendents[a].add(b) return descendents def check_line(line: list[str], ruleset: RuleSet) -> bool: """ Check whether the provided line satisfies the given rules. """ for i, char in enumerate(line): rules = ruleset[char] if any(prevchar in rules for prevchar in line[:i]): return False return True def get_middle(line: list[str]) -> str: """ Get the middle of the line, if it exists. """ if not len(line) % 2: raise ValueError("Even length line has no middle.") return line[len(line) // 2] def process_data(lines: Iterator[str]): lines = map(op.methodcaller('strip'), lines) lines = it.dropwhile(op.not_, lines) ruleset = parse_rules(it.takewhile(bool, lines)) if logger.isEnabledFor(logging.DEBUG): logger.debug(f"Ruleset: {ruleset}") total = 0 for line in lines: if line: parts = line.split(',') if check_line(parts, ruleset): total += int(get_middle(parts)) logger.debug(f"Valid line: {line}") else: logger.debug(f"Invalid line: {line}") return total def main(): with open(sys.argv[1]) as f: total = process_data(f) print(f"Total middle-sum is {total}") if __name__ == '__main__': main()