click options

master
Tom Weber 3 years ago
parent 69aec00459
commit b0e4456270

@ -3,6 +3,7 @@
import re import re
import os import os
import torch import torch
import click
from dotd.model import GRU from dotd.model import GRU
from dotd.data import get_shakespeare, preprocess from dotd.data import get_shakespeare, preprocess
@ -57,13 +58,17 @@ def write_drama(seq, temp=0.7, max_seq_len=1000):
return seq return seq
def main(): @click.command()
@click.option(
"--output", type=str, default="./dramaoftheday.md", help="Output file path"
)
def main(output):
print(__file__) print(__file__)
"""main function""" """main function"""
init_seq = "TOM:" init_seq = "TOM:"
output_text = get_header() + make_pretty(write_drama(init_seq)) output_text = get_header() + make_pretty(write_drama(init_seq))
with open( with open(
os.path.join(os.path.dirname(__file__), "dramaoftheday.md"), output,
"w", "w",
encoding="utf-8", encoding="utf-8",
) as file: ) as file:

@ -1,6 +1,7 @@
# coding: utf-8 # coding: utf-8
""" module for training and creating the model """ """ module for training and creating the model """
import torch import torch
import click
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm # type: ignore from tqdm import tqdm # type: ignore
import numpy as np import numpy as np
@ -55,7 +56,7 @@ def train_epoch(model, loader, optim, loss, device):
return epoch_loss return epoch_loss
def train(epochs: int): def train(epochs: int, save_path: str):
_, _, alphabet_size, _ = preprocess( _, _, alphabet_size, _ = preprocess(
get_shakespeare() get_shakespeare()
) # get amount of characters for one-hot to embedding ) # get amount of characters for one-hot to embedding
@ -73,8 +74,20 @@ def train(epochs: int):
for e in tqdm(range(epochs)): for e in tqdm(range(epochs)):
l = train_epoch(gru, loader, optim, loss, device) l = train_epoch(gru, loader, optim, loss, device)
print(f"Epoch: {e}, Loss: {l}") print(f"Epoch: {e}, Loss: {l}")
torch.save(gru.state_dict(), "saved_models/gru_{}epochs.pth".format(epochs + 1)) torch.save(gru.state_dict(), save_path + f"gru_{epochs+1}epochs.pth")
@click.command()
@click.option(
"--save_path",
type=str,
default="./saved_models",
help="Folder to store the weights in.",
)
@click.option("--epochs", type=int, default=20, help="Number of epochs to train")
def main(epochs, save_path):
train(epochs, save_path)
if __name__ == "__main__": if __name__ == "__main__":
train(20) main()

Loading…
Cancel
Save