diff --git a/src/drama_generator/generate.py b/src/drama_generator/generate.py index 7cd608c..d8a0224 100644 --- a/src/drama_generator/generate.py +++ b/src/drama_generator/generate.py @@ -3,8 +3,8 @@ import re import os import torch -from dotd.model import GRU -from dotd.data import get_shakespeare, preprocess +from drama_generator.model import GRU +from drama_generator.data import get_shakespeare, preprocess def get_header(): diff --git a/src/drama_generator/model.py b/src/drama_generator/model.py index d61fbd1..1148dfb 100644 --- a/src/drama_generator/model.py +++ b/src/drama_generator/model.py @@ -5,7 +5,7 @@ import click from torch.utils.data import DataLoader from tqdm import tqdm # type: ignore import numpy as np -from dotd.data import ShakespeareDataset, preprocess, get_shakespeare +from drama_generator.data import ShakespeareDataset, preprocess, get_shakespeare class GRU(torch.nn.Module):