132 lines
3.5 KiB
Python
132 lines
3.5 KiB
Python
import itertools
|
|
from typing import Iterator
|
|
import sys
|
|
import re
|
|
|
|
# Small 10x10 example grid used by test_XMAS and test_MAS, which expect
# 18 "XMAS" matches and 9 "X-MAS" crosses respectively.
test_data = "\n".join((
    "MMMSXXMASM",
    "MSAMXMSMSA",
    "AMXSXMAAMM",
    "MSAMASMSMX",
    "XMASAMXAMM",
    "XXAMMXXAMA",
    "SMSMSASXSS",
    "SAXAMASAAA",
    "MAMMMXMMMM",
    "MXMXAXMASX",
))
|
|
|
|
# One rule per compass direction: the four (row_step, col_step) offsets
# 0..3 cells away from the start cell. Directions, in order, are
# col+, col-, row+, row-, (row+, col+), (row-, col-), (row+, col-), (row-, col+).
XMAS_RULES = [
    tuple((row_step * dist, col_step * dist) for dist in range(4))
    for row_step, col_step in (
        (0, 1), (0, -1),
        (1, 0), (-1, 0),
        (1, 1), (-1, -1),
        (1, -1), (-1, 1),
    )
]

# The two diagonals of a 3x3 square centred on the anchor cell.
MAS_RULES = [
    ((-1, -1), (0, 0), (1, 1)),  # upper-left -> centre -> lower-right
    ((-1, 1), (0, 0), (1, -1)),  # upper-right -> centre -> lower-left
]

# Default for the search functions' DEBUG parameter: when true they print
# the grid with only the matched letters visible.
DEBUG = True
|
|
|
|
def print_debug_matrix(rows, cols, letter_map: dict[tuple[int, int], str]):
    """Print a rows x cols grid showing only the cells present in letter_map.

    Args:
        rows: number of grid rows to print.
        cols: number of columns to print per row.
        letter_map: maps (row, col) to the letter to display there; every
            cell missing from the map is drawn as '.'.

    Exactly one output line is printed per row.
    """
    for i in range(rows):
        # Assemble the row and print it once: print('\n') would emit two
        # newlines and leave a blank line after every row.
        print(''.join(letter_map.get((i, j), '.') for j in range(cols)))
|
|
|
|
def test_XMAS():
    """Self-check: the example grid must contain exactly 18 "XMAS" words."""
    print("Beginning test matrix search")
    grid = [row for row in (raw.strip() for raw in test_data.splitlines()) if row]
    found = letter_matrix_search_XMAS(grid)
    assert found == 18
    print("Test matrix search passed")
|
|
|
|
def test_MAS():
    """Self-check: the example grid must contain exactly 9 "X-MAS" crosses."""
    print("Beginning MAS matrix search")
    grid = [row for row in (raw.strip() for raw in test_data.splitlines()) if row]
    found = letter_matrix_search_MAS(grid)
    assert found == 9
    print("MAS matrix search passed")
|
|
|
|
def tests():
    """Run the example-grid self-checks: the XMAS search first, then X-MAS."""
    for check in (test_XMAS, test_MAS):
        check()
|
|
|
|
|
|
def matrix_neighbour_rule(matrix: list[str], rowi: int, colj: int, rules: list[tuple[tuple[int, int], ...]]):
    """Collect the cells reached from (rowi, colj) by each offset pattern in rules.

    Each rule is a sequence of (row_step, col_step) offsets added to the
    anchor (rowi, colj). For every rule whose offsets all land inside the
    matrix, one list of ``((row, col), letter)`` pairs is produced, in the
    order the offsets are listed. A rule that steps outside the matrix is
    skipped entirely, not truncated.

    E.g. rules = [((-1, -1), (0, 0), (1, 1))] yields the three cells of the
    diagonal running from the upper-left neighbour, through the anchor, to
    the lower-right neighbour.

    Args:
        matrix: rows of the letter grid.
        rowi: anchor row index.
        colj: anchor column index.
        rules: offset patterns to apply to the anchor.

    Returns:
        One list per in-bounds rule, in rule order. An empty matrix yields
        no cell lists.
    """
    star = []
    rows = len(matrix)
    for ruleset in rules:
        this_line = []
        for rowstep, colstep in ruleset:
            rowii = rowi + rowstep
            colji = colj + colstep
            # Check the column against the target row's own length, so an
            # empty matrix or a ragged row cannot raise IndexError.
            if not (0 <= rowii < rows and 0 <= colji < len(matrix[rowii])):
                break
            this_line.append(((rowii, colji), matrix[rowii][colji]))
        else:
            # Only reached when no offset fell out of bounds.
            star.append(this_line)
    return star
|
|
|
|
|
|
def letter_matrix_search_XMAS(matrix: list[str], DEBUG=DEBUG):
    """Count occurrences of the word "XMAS" in the letter grid.

    Every 'X' is used as a start cell. Each rule in XMAS_RULES (a straight
    4-cell line in one of eight directions) is read from it, and a line
    spelling exactly "XMAS" counts once. Lines that would leave the grid are
    dropped by matrix_neighbour_rule.

    Args:
        matrix: grid rows as strings.
        DEBUG: when true, print the grid with only the matched letters shown.

    Returns:
        The number of "XMAS" occurrences found.
    """
    total = 0
    debug_map = {}
    for rowi, row in enumerate(matrix):
        for colj, letter in enumerate(row):
            if letter != 'X':
                continue
            for line in matrix_neighbour_rule(matrix, rowi, colj, XMAS_RULES):
                if ''.join(ch for _, ch in line) == 'XMAS':
                    total += 1
                    debug_map.update(line)
    # Skip the debug print for an empty grid: there is no matrix[0] to size it.
    if DEBUG and matrix:
        print_debug_matrix(len(matrix), max(map(len, matrix)), debug_map)
    return total
|
|
|
|
def letter_matrix_search_MAS(matrix: list[str], DEBUG=DEBUG):
    """Count "X-MAS" crosses: two diagonal "MAS" strings sharing a centre 'A'.

    Every 'A' is tried as the centre. It counts when every diagonal in
    MAS_RULES lies inside the grid and each reads "MAS" or "SAM".

    Args:
        matrix: grid rows as strings.
        DEBUG: when true, print the grid with only the matched letters shown.

    Returns:
        The number of crosses found.
    """
    total = 0
    debug_map = {}
    for rowi, row in enumerate(matrix):
        for colj, letter in enumerate(row):
            if letter != 'A':
                continue
            star = matrix_neighbour_rule(matrix, rowi, colj, MAS_RULES)
            # matrix_neighbour_rule silently drops out-of-bounds rules, so
            # insist that every diagonal survived, not merely at least one.
            if not star or len(star) != len(MAS_RULES):
                continue
            if all(''.join(ch for _, ch in line) in ('MAS', 'SAM') for line in star):
                total += 1
                for line in star:
                    debug_map.update(line)
    # Skip the debug print for an empty grid: there is no matrix[0] to size it.
    if DEBUG and matrix:
        print_debug_matrix(len(matrix), max(map(len, matrix)), debug_map)
    return total
|
|
|
|
def load_data(filename) -> list[str]:
    """Read the letter grid from filename, one row per non-blank line.

    Surrounding whitespace (including the line terminator) is stripped from
    every line, and lines that are empty after stripping are dropped.

    Args:
        filename: path of the text file to read (decoded as UTF-8).

    Returns:
        The grid rows in file order.
    """
    # Use an explicit encoding rather than the platform-dependent default.
    with open(filename, encoding='utf-8') as f:
        # Iterate the file lazily instead of materialising readlines().
        return [row for row in (line.strip() for line in f) if row]
|
|
|
|
def main():
    """Count "XMAS" words and "X-MAS" crosses in the grid file named on argv.

    The example-grid self-checks run before the real input is processed.
    Debug printing is disabled for the real input.
    """
    if len(sys.argv) < 2:
        # Fail with a usage message instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} INPUT_FILE")
    tests()
    data = load_data(sys.argv[1])
    total = letter_matrix_search_XMAS(data, DEBUG=False)
    print(f"Total XMAS found: {total}")
    total = letter_matrix_search_MAS(data, DEBUG=False)
    print(f"Total MAS found: {total}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
|