# coding: utf-8
"""Generate the drama of the day."""
import os
import re

import torch

from dotd.data import get_shakespeare, preprocess
from dotd.model import GRU


def get_header() -> str:
    """Return the markdown header used by the hugo static-website generator.

    Reads ``data/markdown_header.txt`` relative to this module so the script
    works regardless of the current working directory.
    """
    header_path = os.path.join(os.path.dirname(__file__), "data/markdown_header.txt")
    # `with` ensures the file handle is closed (the original leaked it).
    with open(header_path, "rt", encoding="utf-8") as file:
        return file.read()


def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -> GRU:
    """Load the trained GRU weights, map them to CPU and switch to eval mode.

    :param alphabet_size: size of the character vocabulary the model was
        trained on (input/output dimension).
    :param path: checkpoint path, relative to this module.
    :return: the GRU in eval mode, ready for inference on CPU.
    """
    # Hyper-parameters must match those used at training time for the
    # state dict to load: hidden=1024, embedding=256, 2 layers, batch size 1.
    gru = GRU(alphabet_size, 1024, 256, layers=2, batch=1)
    state_dict = torch.load(
        os.path.join(os.path.dirname(__file__), path),
        map_location=torch.device("cpu"),  # checkpoint may have been saved on GPU
    )
    gru.load_state_dict(state_dict)
    return gru.eval()


def make_pretty(text: str, html: bool = False) -> str:
    """Rewrite line breaks so the generated text renders correctly.

    :param html: if True, emit pure ``<br>`` breaks; otherwise keep the
        newline after the break so the markdown source stays readable.
    """
    # NOTE(review): the replacement strings were garbled in the original
    # source (the markup was stripped by an HTML pass); "<br>" is the
    # reconstruction implied by the docstring — confirm against the repo.
    if html:
        return re.sub(r"\n", "<br>", text)
    return re.sub(r"\n", "<br>\n", text)


def write_drama(seq: str, temp: float = 0.7, max_seq_len: int = 1000) -> str:
    """Generate a drama by sampling characters from the model, one at a time.

    :param seq: seed text that primes the model (e.g. a character name).
    :param temp: sampling temperature; lower values make output more
        deterministic, higher values more varied.
    :param max_seq_len: number of characters to generate after the seed.
    :return: the seed followed by the generated characters.
    """
    int_to_letter, letter_to_int, alphabet_size, _ = preprocess(get_shakespeare())
    gru = load_model(alphabet_size=alphabet_size)
    hidden = gru.init_hidden()
    # Encode the seed characters as a (1, len(seq)) tensor of vocab indices.
    input_idx = torch.LongTensor([[letter_to_int[s] for s in seq]])
    for _ in range(max_seq_len):
        # Logits for the next character; hidden state carries the context.
        output, hidden = gru(input_idx, hidden)
        pred = torch.squeeze(output, 0)[-1]
        pred = pred / temp  # apply temperature before sampling
        pred_id = torch.distributions.categorical.Categorical(logits=pred).sample()
        # Slide the window: drop the oldest index, append the prediction.
        input_idx = torch.cat((input_idx[:, 1:], pred_id.reshape(1, -1)), 1)
        seq += int_to_letter[pred_id.item()]
    return seq


def main(output: str = "drama.md") -> None:
    """Generate the drama and write it (with header) to *output*.

    :param output: destination file path. A default is provided because the
        ``__main__`` guard calls ``main()`` with no argument — the original
        raised TypeError on every direct invocation.
    """
    init_seq = "TOM:"
    output_text = get_header() + make_pretty(write_drama(init_seq))
    with open(output, "w", encoding="utf-8") as file:
        file.write(output_text)


if __name__ == "__main__":
    main()