|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
# coding: utf-8
|
|
|
|
|
|
|
|
|
|
""" generate the drama of the day """
|
|
|
|
|
import re
|
|
|
|
|
import torch
|
|
|
|
|
from dotd.model import GRU
|
|
|
|
@ -7,10 +7,12 @@ from dotd.data import get_shakespeare, preprocess
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_header():
|
|
|
|
|
return open("data/markdown_header.txt", "rt").read()
|
|
|
|
|
"""get markdown header file for hugo static website generator"""
|
|
|
|
|
return open("data/markdown_header.txt", "rt", encoding="utf-8").read()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -> GRU:
|
|
|
|
|
"""load the model, put on cpu and in eval mode"""
|
|
|
|
|
gru = GRU(alphabet_size, 1024, 256, layers=2, batch=1) # instantiate model
|
|
|
|
|
gru.load_state_dict(
|
|
|
|
|
torch.load(
|
|
|
|
@ -22,17 +24,19 @@ def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_pretty(text: str) -> str:
|
|
|
|
|
"""delete some line breaks for markdown"""
|
|
|
|
|
return re.sub(r"\n", r"<br>\n", text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def write_drama(seq, temp=0.7, max_seq_len=1000):
|
|
|
|
|
"""generate the drama starting from a start sequence"""
|
|
|
|
|
int_to_letter, letter_to_int, alphabet_size, _ = preprocess(get_shakespeare())
|
|
|
|
|
gru = load_model(alphabet_size=alphabet_size)
|
|
|
|
|
hidden = gru.init_hidden()
|
|
|
|
|
input_idx = torch.LongTensor(
|
|
|
|
|
[[letter_to_int[s] for s in seq]]
|
|
|
|
|
) # input characters to ints
|
|
|
|
|
for i in range(max_seq_len):
|
|
|
|
|
for _ in range(max_seq_len):
|
|
|
|
|
output, hidden = gru(
|
|
|
|
|
input_idx, hidden
|
|
|
|
|
) # predict the logits for the next character
|
|
|
|
@ -49,10 +53,11 @@ def write_drama(seq, temp=0.7, max_seq_len=1000):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
"""main function"""
|
|
|
|
|
init_seq = "TOM:"
|
|
|
|
|
output_text = get_header() + make_pretty(write_drama(init_seq))
|
|
|
|
|
with open("dramaoftheday.md", "w") as f:
|
|
|
|
|
f.write(output_text)
|
|
|
|
|
with open("dramaoftheday.md", "w", encoding="utf-8") as file:
|
|
|
|
|
file.write(output_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|