click options

master
Tom Weber 3 years ago
parent 69aec00459
commit b0e4456270

@ -3,6 +3,7 @@
import re
import os
import torch
import click
from dotd.model import GRU
from dotd.data import get_shakespeare, preprocess
@ -57,13 +58,17 @@ def write_drama(seq, temp=0.7, max_seq_len=1000):
return seq
def main():
@click.command()
@click.option(
"--output", type=str, default="./dramaoftheday.md", help="Output file path"
)
def main(output):
print(__file__)
"""main function"""
init_seq = "TOM:"
output_text = get_header() + make_pretty(write_drama(init_seq))
with open(
os.path.join(os.path.dirname(__file__), "dramaoftheday.md"),
output,
"w",
encoding="utf-8",
) as file:

@ -1,6 +1,7 @@
# coding: utf-8
""" module for training and creating the model """
import torch
import click
from torch.utils.data import DataLoader
from tqdm import tqdm # type: ignore
import numpy as np
@ -55,7 +56,7 @@ def train_epoch(model, loader, optim, loss, device):
return epoch_loss
def train(epochs: int):
def train(epochs: int, save_path: str):
_, _, alphabet_size, _ = preprocess(
get_shakespeare()
) # get amount of characters for one-hot to embedding
@ -73,8 +74,20 @@ def train(epochs: int):
for e in tqdm(range(epochs)):
l = train_epoch(gru, loader, optim, loss, device)
print(f"Epoch: {e}, Loss: {l}")
torch.save(gru.state_dict(), "saved_models/gru_{}epochs.pth".format(epochs + 1))
torch.save(gru.state_dict(), save_path + f"gru_{epochs+1}epochs.pth")
@click.command()
@click.option(
"--save_path",
type=str,
default="./saved_models",
help="Folder to store the weights in.",
)
@click.option("--epochs", type=int, default=20, help="Number of epochs to train")
def main(epochs, save_path):
train(epochs, save_path)
if __name__ == "__main__":
train(20)
main()

Loading…
Cancel
Save