data imports

master
Tom Weber 3 years ago
parent d1c27f2112
commit 69aec00459

@ -13,7 +13,10 @@ clean-test:
rm -f .coverage rm -f .coverage
rm -f .coverage.* rm -f .coverage.*
clean: clean-pyc clean-test clean-drama:
rm -f ./src/dotd/dramaoftheday.md
clean: clean-pyc clean-test clean-drama
test: clean test: clean
. .venv/bin/activate && py.test tests --cov=src --cov-report=term-missing --cov-fail-under 95 . .venv/bin/activate && py.test tests --cov=src --cov-report=term-missing --cov-fail-under 95

@ -0,0 +1,20 @@
# Packaging script for the "dotd" distribution.
from setuptools import setup, find_packages
setup(
name="dotd",
version="0.1.0",
description="Generating gibberish shakespeare",
author="Tom Weber",
author_email="tom@weber.codes",
# src layout: packages are discovered under ./src and the root package
# namespace ("") is mapped onto that directory.
packages=find_packages(where="src"),
package_dir={"": "src"},
include_package_data=True,
# Non-Python files shipped inside the installed "dotd" package; paths are
# relative to the package directory. These are the saved model file and the
# two text files that the package opens relative to its own location.
package_data={
"dotd": [
"saved_models/rnn_2epochs.pth",
"data/shakespeare.txt",
"data/markdown_header.txt",
]
},
# NOTE(review): "wheel" is usually a build-time tool rather than a runtime
# requirement -- confirm it really needs to be listed here.
install_requires=["wheel", "torch", "numpy"],
)

@ -1,12 +1,17 @@
# coding: utf-8 # coding: utf-8
""" data module """ """ data module """
import torch import torch
import os
from torch.utils.data import Dataset from torch.utils.data import Dataset
def get_shakespeare(): def get_shakespeare():
"""loads the shakespeare text""" """loads the shakespeare text"""
return open("data/shakespeare.txt", "rt", encoding="utf-8").read() return open(
os.path.join(os.path.dirname(__file__), "data/shakespeare.txt"),
"rt",
encoding="utf-8",
).read()
def preprocess(text: str) -> tuple: def preprocess(text: str) -> tuple:

@ -1,6 +1,7 @@
# coding: utf-8 # coding: utf-8
""" generate the drama of the day """ """ generate the drama of the day """
import re import re
import os
import torch import torch
from dotd.model import GRU from dotd.model import GRU
from dotd.data import get_shakespeare, preprocess from dotd.data import get_shakespeare, preprocess
@ -8,7 +9,11 @@ from dotd.data import get_shakespeare, preprocess
def get_header(): def get_header():
"""get markdown header file for hugo static website generator""" """get markdown header file for hugo static website generator"""
return open("data/markdown_header.txt", "rt", encoding="utf-8").read() return open(
os.path.join(os.path.dirname(__file__), "data/markdown_header.txt"),
"rt",
encoding="utf-8",
).read()
def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -> GRU: def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -> GRU:
@ -16,7 +21,7 @@ def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -
gru = GRU(alphabet_size, 1024, 256, layers=2, batch=1) # instantiate model gru = GRU(alphabet_size, 1024, 256, layers=2, batch=1) # instantiate model
gru.load_state_dict( gru.load_state_dict(
torch.load( torch.load(
path, os.path.join(os.path.dirname(__file__), path),
map_location=torch.device("cpu"), map_location=torch.device("cpu"),
) )
) # load weights ) # load weights
@ -53,10 +58,15 @@ def write_drama(seq, temp=0.7, max_seq_len=1000):
def main(): def main():
print(__file__)
"""main function""" """main function"""
init_seq = "TOM:" init_seq = "TOM:"
output_text = get_header() + make_pretty(write_drama(init_seq)) output_text = get_header() + make_pretty(write_drama(init_seq))
with open("dramaoftheday.md", "w", encoding="utf-8") as file: with open(
os.path.join(os.path.dirname(__file__), "dramaoftheday.md"),
"w",
encoding="utf-8",
) as file:
file.write(output_text) file.write(output_text)

Loading…
Cancel
Save