|
|
|
@ -2,7 +2,6 @@
|
|
|
|
|
""" module for training and creating the model """
|
|
|
|
|
import torch
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
from tqdm import tqdm # type: ignore
|
|
|
|
|
import numpy as np
|
|
|
|
|
from drama_generator.data import ShakespeareDataset, preprocess, get_shakespeare
|
|
|
|
|
|
|
|
|
@ -70,7 +69,7 @@ def train(epochs: int, save_path: str):
|
|
|
|
|
num_workers=2,
|
|
|
|
|
drop_last=True,
|
|
|
|
|
)
|
|
|
|
|
for e in tqdm(range(epochs)):
|
|
|
|
|
for e in range(epochs):
|
|
|
|
|
l = train_epoch(gru, loader, optim, loss, device)
|
|
|
|
|
print(f"Epoch: {e}, Loss: {l}")
|
|
|
|
|
torch.save(gru.state_dict(), save_path + f"gru_{epochs+1}epochs.pth")
|
|
|
|
|