From b0e4456270ff2628cdb3eae3871c23d5d69edd96 Mon Sep 17 00:00:00 2001 From: Tom Weber Date: Tue, 8 Feb 2022 13:27:53 +0100 Subject: [PATCH] click options --- src/dotd/generate.py | 9 +++++++-- src/dotd/model.py | 19 ++++++++++++++++--- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/dotd/generate.py b/src/dotd/generate.py index 0e0f45c..beb7b88 100644 --- a/src/dotd/generate.py +++ b/src/dotd/generate.py @@ -3,6 +3,7 @@ import re import os import torch +import click from dotd.model import GRU from dotd.data import get_shakespeare, preprocess @@ -57,13 +58,17 @@ def write_drama(seq, temp=0.7, max_seq_len=1000): return seq -def main(): +@click.command() +@click.option( + "--output", type=str, default="./dramaoftheday.md", help="Output file path" +) +def main(output): print(__file__) """main function""" init_seq = "TOM:" output_text = get_header() + make_pretty(write_drama(init_seq)) with open( - os.path.join(os.path.dirname(__file__), "dramaoftheday.md"), + output, "w", encoding="utf-8", ) as file: diff --git a/src/dotd/model.py b/src/dotd/model.py index c7df003..d61fbd1 100644 --- a/src/dotd/model.py +++ b/src/dotd/model.py @@ -1,6 +1,7 @@ # coding: utf-8 """ module for training and creating the model """ import torch +import click from torch.utils.data import DataLoader from tqdm import tqdm # type: ignore import numpy as np @@ -55,7 +56,7 @@ def train_epoch(model, loader, optim, loss, device): return epoch_loss -def train(epochs: int): +def train(epochs: int, save_path: str): _, _, alphabet_size, _ = preprocess( get_shakespeare() ) # get amount of characters for one-hot to embedding @@ -73,8 +74,20 @@ def train(epochs: int): for e in tqdm(range(epochs)): l = train_epoch(gru, loader, optim, loss, device) print(f"Epoch: {e}, Loss: {l}") - torch.save(gru.state_dict(), "saved_models/gru_{}epochs.pth".format(epochs + 1)) + torch.save(gru.state_dict(), save_path + f"gru_{epochs+1}epochs.pth") + + +@click.command() +@click.option( + "--save_path", + type=str, + default="./saved_models", + help="Folder to store the weights in.", +) +@click.option("--epochs", type=int, default=20, help="Number of epochs to train") +def main(epochs, save_path): + train(epochs, save_path) if __name__ == "__main__": - train(20) + main()