import itertools
from typing import Iterable, Iterator
import sys
import re

test_data = r"""
MMMSXXMASM
MSAMXMSMSA
AMXSXMAAMM
MSAMASMSMX
XMASAMXAMM
XXAMMXXAMA
SMSMSASXSS
SAXAMASAAA
MAMMMXMMMM
MXMXAXMASX
""".strip()

# The eight compass directions as (row step, column step); row +1 is downward.
_DIRECTIONS = (
    (0, 1), (0, -1), (1, 0), (-1, 0),
    (1, 1), (-1, -1), (1, -1), (-1, 1),
)

# One ruleset per direction: the offsets of the 4 cells (len("XMAS")) of a
# straight word that starts at the origin cell and runs in that direction.
XMAS_RULES = [
    tuple((i * drow, i * dcol) for i in range(4))
    for drow, dcol in _DIRECTIONS
]

# The two diagonals of a 3x3 'X' centred on the origin cell:
# top-left -> bottom-right, then top-right -> bottom-left.
MAS_RULES = [
    ((-1, -1), (0, 0), (1, 1)),
    ((-1, 1), (0, 0), (1, -1)),
]

DEBUG = True


def print_debug_matrix(rows, cols, letter_map: dict[tuple[int, int], str]) -> None:
    """Print a rows x cols grid showing only the cells present in *letter_map*.

    Cells missing from *letter_map* are drawn as '.', so matched words stand
    out. Each grid row is printed on exactly one output line.
    """
    for i in range(rows):
        print(''.join(letter_map.get((i, j), '.') for j in range(cols)))


def _parse_grid(lines: Iterable[str]) -> list[str]:
    """Return the stripped, non-blank entries of *lines* as grid rows."""
    return [stripped for stripped in (line.strip() for line in lines) if stripped]


def test_XMAS():
    print("Beginning test matrix search")
    test_matrix = _parse_grid(test_data.splitlines())
    assert letter_matrix_search_XMAS(test_matrix) == 18
    print("Test matrix search passed")


def test_MAS():
    print("Beginning MAS matrix search")
    test_matrix = _parse_grid(test_data.splitlines())
    assert letter_matrix_search_MAS(test_matrix) == 9
    print("MAS matrix search passed")


def tests():
    """Run both self-tests against the built-in example grid."""
    test_XMAS()
    test_MAS()


def matrix_neighbour_rule(
    matrix: list[str],
    rowi,
    colj,
    rules: list[tuple[tuple[int, int], ...]],
) -> list[list[tuple[tuple[int, int], str]]]:
    """Collect the cells reached from (rowi, colj) by each ruleset in *rules*.

    Each ruleset is a sequence of (row offset, column offset) pairs. Its
    result is a list of ((row, col), letter) pairs in the same order. For
    example, rules = [((1, -1), (0, 0), (-1, 1))] yields the '/' diagonal
    centred on the point, read bottom-left to top-right.

    A ruleset that would step outside the grid is omitted entirely (it is not
    truncated), so the result may contain fewer entries than *rules*. The
    column count is taken from the first row, so *matrix* is assumed to be
    rectangular.
    """
    rows = len(matrix)
    cols = len(matrix[0]) if matrix else 0
    star = []
    for ruleset in rules:
        this_line = []
        for rowstep, colstep in ruleset:
            rowii, colji = rowi + rowstep, colj + colstep
            if not (0 <= rowii < rows and 0 <= colji < cols):
                break  # this ruleset leaves the grid: drop it
            this_line.append(((rowii, colji), matrix[rowii][colji]))
        else:
            star.append(this_line)
    return star


def _line_text(line: list[tuple[tuple[int, int], str]]) -> str:
    """Concatenate the letters of one line returned by matrix_neighbour_rule."""
    return ''.join(ch for _, ch in line)


def _print_matches(matrix: list[str], debug_map: dict[tuple[int, int], str]) -> None:
    """Draw a grid the size of *matrix* with only the matched cells filled in."""
    print_debug_matrix(len(matrix), len(matrix[0]) if matrix else 0, debug_map)


def letter_matrix_search_XMAS(matrix: list[str], DEBUG=DEBUG) -> int:
    """Count occurrences of 'XMAS' in any of the 8 straight directions.

    Every 'X' is a potential start, so overlapping and reversed words are all
    counted. When DEBUG is true, the matched cells are printed afterwards.
    """
    total = 0
    debug_map: dict[tuple[int, int], str] = {}
    for rowi, row in enumerate(matrix):
        for colj, letter in enumerate(row):
            if letter != 'X':
                continue
            for line in matrix_neighbour_rule(matrix, rowi, colj, XMAS_RULES):
                if _line_text(line) == 'XMAS':
                    total += 1
                    debug_map.update(line)
    if DEBUG:
        _print_matches(matrix, debug_map)
    return total


def letter_matrix_search_MAS(matrix: list[str], DEBUG=DEBUG) -> int:
    """Count 'X-MAS' shapes: two diagonal 'MAS' words crossing at an 'A'.

    Each diagonal may read 'MAS' or 'SAM'. Both diagonals in MAS_RULES must lie
    fully inside the grid. When DEBUG is true, the matched cells are printed
    afterwards.
    """
    total = 0
    debug_map: dict[tuple[int, int], str] = {}
    for rowi, row in enumerate(matrix):
        for colj, letter in enumerate(row):
            if letter != 'A':
                continue
            star = matrix_neighbour_rule(matrix, rowi, colj, MAS_RULES)
            # matrix_neighbour_rule silently drops out-of-bounds rulesets; a
            # partial cross must not count as a match.
            complete = bool(star) and len(star) == len(MAS_RULES)
            if complete and all(_line_text(line) in ('MAS', 'SAM') for line in star):
                total += 1
                for line in star:
                    debug_map.update(line)
    if DEBUG:
        _print_matches(matrix, debug_map)
    return total


def load_data(filename) -> list[str]:
    """Read *filename* and return its stripped, non-blank lines as grid rows."""
    with open(filename, encoding='utf-8') as f:
        return _parse_grid(f)


def main():
    """Run the self-tests, then solve both parts for the file named in argv[1]."""
    tests()
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} INPUT_FILE")
    data = load_data(sys.argv[1])
    total = letter_matrix_search_XMAS(data, DEBUG=False)
    print(f"Total XMAS found: {total}")
    total = letter_matrix_search_MAS(data, DEBUG=False)
    print(f"Total MAS found: {total}")


if __name__ == '__main__':
    main()