#!/usr/bin/env python
# coding: utf-8
"""Generate a "drama of the day" with a character-level GRU language model.

The script rebuilds the character vocabulary from the training corpus,
restores the trained weights, samples text starting from a seed string and
writes the result, prefixed with a Markdown header, to ``dramaoftheday.md``.
"""

import torch
import re

# Directory holding the corpus, the Markdown header, the weights and the output.
BASE_DIR = "/home/tux/shakespeare_generator"


class RNN(torch.nn.Module):
    """Character-level language model: embedding -> stacked GRU -> linear logits.

    Args:
        vocab_size: number of distinct characters (input ids and output classes).
        hidden_size: width of each GRU layer.
        embedding_size: dimension of the character embeddings.
        batch: batch size used by :meth:`initHidden` to shape the initial state.
        layers: number of stacked GRU layers.
    """

    def __init__(self, vocab_size, hidden_size, embedding_size, batch=32, layers=2):
        super().__init__()
        self.hidden_size = hidden_size  # size of the GRU layers
        self.batch = batch
        self.layers = layers  # how many GRU layers
        # NOTE: these attribute names are the state_dict keys of saved weights;
        # renaming them would break load_state_dict on existing checkpoints.
        self.word_embeds = torch.nn.Embedding(vocab_size, embedding_size)
        self.gru = torch.nn.GRU(embedding_size, hidden_size, layers, batch_first=True)
        self.output_layer = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, inputs, hidden):
        """Return ``(logits, hidden)`` for a batch of character-id sequences.

        ``inputs`` is laid out batch-first (the GRU is built with
        ``batch_first=True``); ``hidden`` is the GRU state, e.g. from
        :meth:`initHidden` or a previous call.
        """
        x = self.word_embeds(inputs)  # integer ids -> dense embeddings
        output, hidden = self.gru(x, hidden)
        return self.output_layer(output), hidden

    def initHidden(self):
        """Return an all-zero GRU state of shape (layers, batch, hidden_size)."""
        return torch.zeros(self.layers, self.batch, self.hidden_size)


def preprocess(text):
    """Build the character vocabulary of *text* and encode the text with it.

    The alphabet is the sorted set of characters in *text*, so the same text
    always produces the same character -> index assignment.

    Returns:
        ``(int_to_letter, letter_to_int, alphabet_size, letter_ints)``.
    """
    alphabet = sorted(set(text))
    letter_to_int = {let: ind for ind, let in enumerate(alphabet)}
    int_to_letter = dict(enumerate(alphabet))
    letter_ints = [letter_to_int[letter] for letter in text]
    return int_to_letter, letter_to_int, len(alphabet), letter_ints


def markdown_header():
    """Return the Markdown header that is prepended to the generated drama."""
    with open(f"{BASE_DIR}/markdown_header.txt", "rt") as header_file:
        return header_file.read()


with open(f"{BASE_DIR}/shakespeare.txt", "rt") as corpus_file:
    text = corpus_file.read()
init_seq = "TOM:"

# The vocabulary is rebuilt from the corpus; presumably it must be identical to
# the one used at training time for the saved output layer to line up.
int_to_letter, letter_to_int, alphabet_size, letter_ints = preprocess(text)

rnn = RNN(alphabet_size, 1024, 256, layers=2, batch=1)  # instantiate model
rnn.load_state_dict(
    torch.load(f"{BASE_DIR}/rnn_2epochs.pth", map_location=torch.device('cpu'))
)  # load weights
rnn.eval()  # inference mode


def write_drama(seq, temp=0.7, max_seq_len=1000):
    """Sample ``max_seq_len`` characters from the model, continuing *seq*.

    Args:
        seq: non-empty seed string; every character must be in the corpus
            alphabet (otherwise ``letter_to_int`` raises ``KeyError``).
        temp: sampling temperature; logits are divided by it, so lower values
            make the output more conservative.
        max_seq_len: number of characters to generate after the seed.

    Returns:
        The seed followed by the generated characters.
    """
    hidden = rnn.initHidden()
    # Prime the GRU with the full seed once; afterwards only the newly sampled
    # character is fed, because `hidden` already summarises everything before
    # it (re-feeding old characters would make the model read them twice).
    input_idx = torch.LongTensor([[letter_to_int[s] for s in seq]])
    pieces = [seq]
    # no_grad: `hidden` is threaded through every step, so without it the
    # autograd graph (and memory use) would grow for the whole generation.
    with torch.no_grad():
        for _ in range(max_seq_len):
            output, hidden = rnn(input_idx, hidden)
            logits = output[0, -1] / temp  # logits for the next character
            pred_id = torch.distributions.categorical.Categorical(logits=logits).sample()
            input_idx = pred_id.reshape(1, 1)  # (batch=1, seq_len=1)
            pieces.append(int_to_letter[pred_id.item()])
    return "".join(pieces)


def stylise_drama(drama):
    """Turn every newline into a Markdown hard line break (two trailing spaces)."""
    return re.sub(r"\n", r"  \n", drama)


output_text = markdown_header() + stylise_drama(write_drama(init_seq))
with open(f"{BASE_DIR}/dramaoftheday.md", "w") as f:
    f.write(output_text)