You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

73 lines
2.2 KiB

import numpy as np
# Path to the puzzle input. Each non-blank line is expected to hold the signal
# patterns, a "|" separator, then the output patterns (decode_output splits
# the line on "|").
file = "inputs/day8.input"
# Read as UTF-8 text inside a context manager so the handle is closed (the
# previous bare open() was never closed). Text mode also normalises CRLF line
# endings. Blank lines, such as a trailing empty line, are skipped because
# decode_output cannot handle a line without a "|".
with open(file, encoding="utf-8") as input_handle:
    numbers = [line.rstrip("\n") for line in input_handle if line.strip()]
# Pattern lengths that sort_strings pre-creates buckets for; the solver reads
# the buckets for lengths 2, 3, 4, 5, 6 and 7.
lens = [2, 3, 4, 5, 6, 7]
def solve_string(string):
    """Work out which digit each segment pattern in a puzzle line displays.

    Every whitespace-separated token of ``string`` is normalised by sorting
    its letters and duplicates are removed. Tokens of length 1 (such as the
    "|" separator) are dropped by ``sort_strings``.

    Returns a dict mapping each sorted pattern to the digit 0-9 it shows.
    """
    patterns = np.unique(["".join(sorted(token)) for token in string.split()])
    by_length = sort_strings(patterns)

    # Digits 1, 7, 4 and 8 are identified purely by their segment count.
    one = by_length[2][0]
    seven = by_length[3][0]
    four = by_length[4][0]
    eight = by_length[7][0]
    digit_to_pattern = {1: one, 4: four, 7: seven, 8: eight}

    # Six-segment patterns (0, 6, 9) are told apart using 1 and 4.
    six_lit = by_length[6]
    digit_to_pattern.update(zip(solve_6s(six_lit, one, four), six_lit))

    # Five-segment patterns (2, 3, 5) are told apart using 6 and 7.
    five_lit = by_length[5]
    digit_to_pattern.update(
        zip(solve_5s(five_lit, digit_to_pattern[6], seven), five_lit)
    )

    # Invert so callers can look a pattern up and get its digit.
    return {pattern: digit for digit, pattern in digit_to_pattern.items()}
def sort_strings(digits):
    """Group segment patterns into lists keyed by pattern length.

    One list (possibly empty) is created for every length in the
    module-level ``lens``. Patterns of length 0 or 1 are skipped. A longer
    pattern whose length is not in ``lens`` raises KeyError.
    """
    buckets = {length: [] for length in lens}
    for pattern in digits:
        # Skip separators such as "|" that are not real patterns.
        if len(pattern) <= 1:
            continue
        buckets[len(pattern)].append(pattern)
    return buckets
def solve_5s(fives, cypher_six, cypher_seven):
    """Classify each five-segment pattern as the digit 2, 3 or 5.

    Rules:
    - A pattern containing every segment of 7 is 3.
    - Otherwise, a pattern whose segments all appear in 6 is 5.
    - Anything else is 2.

    Returns a list of digits in the same order as ``fives``.
    """
    def digit_of(pattern):
        if all(segment in pattern for segment in cypher_seven):
            return 3
        if all(segment in cypher_six for segment in pattern):
            return 5
        return 2

    return [digit_of(pattern) for pattern in fives]
def solve_6s(sixes, cypher_one, cypher_four):
    """Classify each six-segment pattern as the digit 0, 6 or 9.

    Rules:
    - A pattern missing any segment of 1 is 6.
    - Otherwise, a pattern missing any segment of 4 is 0.
    - Anything else is 9.

    Returns a list of digits in the same order as ``sixes``.
    """
    def digit_of(pattern):
        if not all(segment in pattern for segment in cypher_one):
            return 6
        if not all(segment in pattern for segment in cypher_four):
            return 0
        return 9

    return list(map(digit_of, sixes))
def decode_output(string):
    """Decode the output patterns after the "|" of a puzzle line into an int.

    The whole line is passed to ``solve_string`` to build the pattern-to-digit
    lookup. Each output token is looked up after sorting its letters, and the
    resulting digits are concatenated and parsed as an int.
    """
    output_part = string.split("|")[1]
    lookup = solve_string(string)
    digit_text = "".join(
        str(lookup["".join(sorted(token))]) for token in output_part.split()
    )
    return int(digit_text)
# Total of the decoded output values across every input line.
summed = sum(map(decode_output, numbers))
print(summed)