data imports

master
Tom Weber 3 years ago
parent d1c27f2112
commit 69aec00459

@ -13,7 +13,10 @@ clean-test:
rm -f .coverage
rm -f .coverage.*
clean: clean-pyc clean-test
clean-drama:
rm -f ./src/dotd/dramaoftheday.md
clean: clean-pyc clean-test clean-drama
test: clean
. .venv/bin/activate && py.test tests --cov=src --cov-report=term-missing --cov-fail-under 95

@ -0,0 +1,20 @@
from setuptools import setup, find_packages
setup(
name="dotd",
version="0.1.0",
description="Generating gibberish shakespeare",
author="Tom Weber",
author_email="tom@weber.codes",
packages=find_packages(where="src"),
package_dir={"": "src"},
include_package_data=True,
package_data={
"dotd": [
"saved_models/rnn_2epochs.pth",
"data/shakespeare.txt",
"data/markdown_header.txt",
]
},
install_requires=["wheel", "torch", "numpy"],
)

@ -1,12 +1,17 @@
# coding: utf-8
""" data module """
import torch
import os
from torch.utils.data import Dataset
def get_shakespeare():
"""loads the shakespeare text"""
return open("data/shakespeare.txt", "rt", encoding="utf-8").read()
return open(
os.path.join(os.path.dirname(__file__), "data/shakespeare.txt"),
"rt",
encoding="utf-8",
).read()
def preprocess(text: str) -> tuple:

@ -1,6 +1,7 @@
# coding: utf-8
""" generate the drama of the day """
import re
import os
import torch
from dotd.model import GRU
from dotd.data import get_shakespeare, preprocess
@ -8,7 +9,11 @@ from dotd.data import get_shakespeare, preprocess
def get_header():
"""get markdown header file for hugo static website generator"""
return open("data/markdown_header.txt", "rt", encoding="utf-8").read()
return open(
os.path.join(os.path.dirname(__file__), "data/markdown_header.txt"),
"rt",
encoding="utf-8",
).read()
def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -> GRU:
@ -16,7 +21,7 @@ def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -
gru = GRU(alphabet_size, 1024, 256, layers=2, batch=1) # instantiate model
gru.load_state_dict(
torch.load(
path,
os.path.join(os.path.dirname(__file__), path),
map_location=torch.device("cpu"),
)
) # load weights
@ -53,10 +58,15 @@ def write_drama(seq, temp=0.7, max_seq_len=1000):
def main():
print(__file__)
"""main function"""
init_seq = "TOM:"
output_text = get_header() + make_pretty(write_drama(init_seq))
with open("dramaoftheday.md", "w", encoding="utf-8") as file:
with open(
os.path.join(os.path.dirname(__file__), "dramaoftheday.md"),
"w",
encoding="utf-8",
) as file:
file.write(output_text)

Loading…
Cancel
Save