|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
# coding: utf-8
|
|
|
|
|
""" module for training and creating the model """
|
|
|
|
|
import torch
|
|
|
|
|
import click
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
from tqdm import tqdm # type: ignore
|
|
|
|
|
import numpy as np
|
|
|
|
@ -55,7 +56,7 @@ def train_epoch(model, loader, optim, loss, device):
|
|
|
|
|
return epoch_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(epochs: int):
|
|
|
|
|
def train(epochs: int, save_path: str):
|
|
|
|
|
_, _, alphabet_size, _ = preprocess(
|
|
|
|
|
get_shakespeare()
|
|
|
|
|
) # get amount of characters for one-hot to embedding
|
|
|
|
@ -73,8 +74,20 @@ def train(epochs: int):
|
|
|
|
|
for e in tqdm(range(epochs)):
|
|
|
|
|
l = train_epoch(gru, loader, optim, loss, device)
|
|
|
|
|
print(f"Epoch: {e}, Loss: {l}")
|
|
|
|
|
torch.save(gru.state_dict(), "saved_models/gru_{}epochs.pth".format(epochs + 1))
|
|
|
|
|
torch.save(gru.state_dict(), save_path + f"gru_{epochs+1}epochs.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@click.command()
|
|
|
|
|
@click.option(
|
|
|
|
|
"--save_path",
|
|
|
|
|
type=str,
|
|
|
|
|
default="./saved_models",
|
|
|
|
|
help="Folder to store the weights in.",
|
|
|
|
|
)
|
|
|
|
|
@click.option("--epochs", type=int, default=20, help="Number of epochs to train")
|
|
|
|
|
def main(epochs, save_path):
|
|
|
|
|
train(epochs, save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
train(20)
|
|
|
|
|
main()
|
|
|
|
|