diff --git a/src/drama_generator/model.py b/src/drama_generator/model.py index 1148dfb..f7c1fca 100644 --- a/src/drama_generator/model.py +++ b/src/drama_generator/model.py @@ -1,7 +1,6 @@ # coding: utf-8 """ module for training and creating the model """ import torch -import click from torch.utils.data import DataLoader from tqdm import tqdm # type: ignore import numpy as np @@ -77,17 +76,3 @@ def train(epochs: int, save_path: str): torch.save(gru.state_dict(), save_path + f"gru_{epochs+1}epochs.pth") -@click.command() -@click.option( - "--save_path", - type=str, - default="./saved_models", - help="Folder to store the weights in.", -) -@click.option("--epochs", type=int, default=20, help="Number of epochs to train") -def main(epochs, save_path): - train(epochs, save_path) - - -if __name__ == "__main__": - main()