You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

77 lines
2.3 KiB

# coding: utf-8
""" generate the drama of the day """
import re
import os
import torch
from dotd.model import GRU
from dotd.data import get_shakespeare, preprocess
def get_header():
    """Return the markdown header used for the hugo static website generator.

    The header is read as UTF-8 text from ``data/markdown_header.txt``,
    resolved relative to the directory that contains this file.
    """
    header_path = os.path.join(os.path.dirname(__file__), "data/markdown_header.txt")
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(header_path, "rt", encoding="utf-8") as handle:
        return handle.read()
def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -> GRU:
    """Create a GRU, restore its saved weights on the CPU and switch it to eval mode.

    Args:
        alphabet_size: first positional argument passed to the ``GRU`` constructor.
        path: location of the saved state dict, relative to this file's directory.

    Returns:
        Whatever ``eval()`` returns for the freshly loaded model.
    """
    model = GRU(alphabet_size, 1024, 256, layers=2, batch=1)
    weights_file = os.path.join(os.path.dirname(__file__), path)
    # map_location forces the stored tensors onto the CPU.
    cpu = torch.device("cpu")
    state_dict = torch.load(weights_file, map_location=cpu)
    model.load_state_dict(state_dict)
    return model.eval()
def make_pretty(text: str, html=False) -> str:
    """Turn every line break in ``text`` into an explicit ``<br>`` tag.

    Args:
        text: the text to convert.
        html: when true, each newline is replaced by ``<br>`` alone; otherwise
            the newline is kept after the tag (``<br>`` followed by a newline).

    Returns:
        The converted text.
    """
    # A plain literal replacement needs no regular expression.
    replacement = "<br>" if html else "<br>\n"
    return text.replace("\n", replacement)
def write_drama(seq, temp=0.7, max_seq_len=1000):
    """Generate drama text by sampling characters that continue ``seq``.

    Args:
        seq: start sequence; every character is mapped to an integer through
            the ``letter_to_int`` lookup returned by ``preprocess``.
        temp: temperature; the logits are divided by it before sampling.
        max_seq_len: number of characters to sample.

    Returns:
        ``seq`` followed by the ``max_seq_len`` sampled characters.
    """
    int_to_letter, letter_to_int, alphabet_size, _ = preprocess(get_shakespeare())
    gru = load_model(alphabet_size=alphabet_size)
    hidden = gru.init_hidden()
    # Input characters to ints, as a single batch row.
    input_idx = torch.LongTensor([[letter_to_int[s] for s in seq]])
    generated = []
    # Pure inference: without no_grad the carried hidden state would keep an
    # ever-growing autograd graph alive across all iterations.
    with torch.no_grad():
        for _ in range(max_seq_len):
            # Predict the logits for the next character.
            output, hidden = gru(input_idx, hidden)
            pred = torch.squeeze(output, 0)[-1] / temp  # apply temperature
            # Sample the next character id from the distribution.
            pred_id = torch.distributions.categorical.Categorical(
                logits=pred
            ).sample()
            # Slide the window: drop the oldest id, append the predicted one.
            input_idx = torch.cat((input_idx[:, 1:], pred_id.reshape(1, -1)), 1)
            generated.append(int_to_letter[pred_id.item()])
    # Join once at the end instead of growing the string inside the loop.
    return seq + "".join(generated)
def main(output):
    """Generate a drama and write it, prefixed by the markdown header, to ``output``.

    Args:
        output: path of the file to create or overwrite; written as UTF-8.
    """
    # The docstring has to be the first statement; it used to follow this
    # print and was therefore only a discarded string expression.
    print(__file__)
    init_seq = "TOM:"
    output_text = get_header() + make_pretty(write_drama(init_seq))
    with open(output, "w", encoding="utf-8") as handle:
        handle.write(output_text)
if __name__ == "__main__":
    # main() requires the output path; calling it without arguments always
    # raised TypeError, so read the path from the command line instead.
    import argparse

    parser = argparse.ArgumentParser(description="generate the drama of the day")
    parser.add_argument("output", help="path of the markdown file to write")
    main(parser.parse_args().output)