|
|
|
@ -1,7 +1,6 @@
|
|
|
|
|
# coding: utf-8
|
|
|
|
|
""" module for training and creating the model """
|
|
|
|
|
import torch
|
|
|
|
|
import click
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
from tqdm import tqdm # type: ignore
|
|
|
|
|
import numpy as np
|
|
|
|
@ -77,17 +76,3 @@ def train(epochs: int, save_path: str):
|
|
|
|
|
torch.save(gru.state_dict(), save_path + f"gru_{epochs+1}epochs.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@click.command()
|
|
|
|
|
@click.option(
|
|
|
|
|
"--save_path",
|
|
|
|
|
type=str,
|
|
|
|
|
default="./saved_models",
|
|
|
|
|
help="Folder to store the weights in.",
|
|
|
|
|
)
|
|
|
|
|
@click.option("--epochs", type=int, default=20, help="Number of epochs to train")
|
|
|
|
|
def main(epochs, save_path):
|
|
|
|
|
train(epochs, save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
main()
|
|
|
|
|