You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
77 lines
2.3 KiB
77 lines
2.3 KiB
# coding: utf-8
|
|
""" generate the drama of the day """
|
|
import re
|
|
import os
|
|
import torch
|
|
from dotd.model import GRU
|
|
from dotd.data import get_shakespeare, preprocess
|
|
|
|
|
|
def get_header(path: str = "data/markdown_header.txt") -> str:
    """Read the markdown header file used for the Hugo static website generator.

    Args:
        path: Location of the header file, relative to this module's directory.

    Returns:
        The whole file contents, decoded as UTF-8.
    """
    header_file = os.path.join(os.path.dirname(__file__), path)
    # ``with`` closes the handle deterministically; the previous bare
    # ``open(...).read()`` left closing it to the garbage collector.
    with open(header_file, "rt", encoding="utf-8") as handle:
        return handle.read()
|
|
|
|
|
|
def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -> GRU:
    """Build the GRU and restore its saved weights for inference on the CPU.

    Args:
        alphabet_size: Vocabulary size, forwarded as GRU's first argument.
        path: Checkpoint location, relative to this module's directory.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.
    """
    # NOTE(review): the positional 1024 / 256 are presumably hidden and
    # embedding sizes; they must match the checkpoint — confirm in dotd.model.
    model = GRU(alphabet_size, 1024, 256, layers=2, batch=1)
    checkpoint = os.path.join(os.path.dirname(__file__), path)
    # torch.load unpickles the file, so only point this at trusted checkpoints.
    state_dict = torch.load(checkpoint, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    # Module.eval() returns the module itself, so it can be returned directly.
    return model.eval()
|
|
|
|
|
|
def make_pretty(text: str, html=False) -> str:
    """Rewrite newlines as ``<br>`` tags so line breaks survive rendering.

    Args:
        text: The raw generated text.
        html: If true, every newline is replaced by ``<br>`` alone; otherwise
            by ``<br>`` followed by the original newline (markdown-friendly).

    Returns:
        The text with every newline rewritten.
    """
    line_break = "<br>" if html else "<br>\n"
    return text.replace("\n", line_break)
|
|
|
|
|
|
def write_drama(seq, temp=0.7, max_seq_len=1000):
    """Generate a drama by sampling characters from the GRU, seeded with ``seq``.

    At every step the model receives a window of ``len(seq)`` character
    indices together with the carried hidden state. The next character is
    sampled from the temperature-scaled logits of the last position, then
    appended to the window while the oldest index is dropped.

    Args:
        seq: Start sequence. Every character must be present in the alphabet
            returned by ``preprocess``; otherwise the ``letter_to_int`` lookup
            fails (KeyError if it is a dict).
        temp: Sampling temperature, must be > 0. Logits are divided by it, so
            lower values sharpen the distribution and higher values flatten it.
        max_seq_len: Number of characters to generate after ``seq``.

    Returns:
        ``seq`` followed by ``max_seq_len`` sampled characters.

    Raises:
        ValueError: if ``seq`` is empty or ``temp`` is not positive.
    """
    if not seq:
        # An empty window leaves no last position to predict from.
        raise ValueError("seq must contain at least one character")
    if temp <= 0:
        # Dividing the logits by zero (or flipping their sign) is meaningless.
        raise ValueError(f"temp must be positive, got {temp}")

    int_to_letter, letter_to_int, alphabet_size, _ = preprocess(get_shakespeare())
    gru = load_model(alphabet_size=alphabet_size)
    hidden = gru.init_hidden()
    input_idx = torch.LongTensor([[letter_to_int[s] for s in seq]])  # shape (1, len(seq))

    generated = []  # collect characters, join once at the end
    # Sampling needs no gradients; no_grad stops autograd from recording the
    # chain of forward passes that ``hidden`` links together across steps.
    with torch.no_grad():
        for _ in range(max_seq_len):
            output, hidden = gru(input_idx, hidden)
            logits = torch.squeeze(output, 0)[-1] / temp  # last position, temperature applied
            pred_id = torch.distributions.categorical.Categorical(logits=logits).sample()
            # Slide the window: drop the oldest index, append the sampled one.
            input_idx = torch.cat((input_idx[:, 1:], pred_id.reshape(1, -1)), 1)
            generated.append(int_to_letter[pred_id.item()])

    return seq + "".join(generated)
|
|
|
|
|
|
def main(output=None):
    """Generate the drama of the day and emit it as a Hugo markdown page.

    The page is the markdown header followed by a drama seeded with "TOM:",
    with newlines rewritten as ``<br>`` line breaks.

    Args:
        output: Path of the markdown file to write (UTF-8, overwritten). When
            ``None``, the page is printed to stdout instead.
    """
    init_seq = "TOM:"
    output_text = get_header() + make_pretty(write_drama(init_seq))
    if output is None:
        print(output_text)
        return
    with open(output, "w", encoding="utf-8") as file:
        file.write(output_text)
|
|
|
|
|
|
if __name__ == "__main__":
    # main() requires an output path, so read it from the command line
    # instead of calling main() with no arguments (which raised TypeError).
    import argparse

    _parser = argparse.ArgumentParser(description="Generate the drama of the day.")
    _parser.add_argument("output", help="path of the markdown file to write")
    main(_parser.parse_args().output)
|