From 69aec0045970db92173a2e5b4395a89ccdfadaff Mon Sep 17 00:00:00 2001
From: Tom Weber
Date: Tue, 8 Feb 2022 12:04:40 +0100
Subject: [PATCH] data imports

---
Notes (not part of the commit message):

* data.py / generate.py: the shakespeare text, the markdown header, the
  model weights and the generated dramaoftheday.md are now located
  relative to os.path.dirname(__file__) instead of the current working
  directory.
* setup.py: new file; declares the src/ layout and ships the model
  weights and the two data files as package data.
* Makefile: new clean-drama target removes the generated
  src/dotd/dramaoftheday.md and is wired into clean.
* Review change: the stray print(__file__) that had been added as the
  first statement of main() is dropped. It sat ahead of the docstring,
  so the string below it was no longer the function's docstring, and it
  printed on every call. The last generate.py hunk header and the
  diffstat are adjusted for the removed line.
* NOTE(review): the post-image blob id on the generate.py index line was
  not recomputed after that removal; regenerate the patch with
  git format-patch if the id matters.

 Makefile             |  5 ++++-
 setup.py             | 20 ++++++++++++++++++++
 src/dotd/data.py     |  7 ++++++-
 src/dotd/generate.py | 15 ++++++++++++---
 4 files changed, 42 insertions(+), 5 deletions(-)
 create mode 100644 setup.py

diff --git a/Makefile b/Makefile
index c5af3ab..6205296 100644
--- a/Makefile
+++ b/Makefile
@@ -13,7 +13,10 @@ clean-test:
 	rm -f .coverage
 	rm -f .coverage.*
 
-clean: clean-pyc clean-test
+clean-drama:
+	rm -f ./src/dotd/dramaoftheday.md
+
+clean: clean-pyc clean-test clean-drama
 
 test: clean
 	. .venv/bin/activate && py.test tests --cov=src --cov-report=term-missing --cov-fail-under 95
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..4699016
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,20 @@
+from setuptools import setup, find_packages
+
+setup(
+    name="dotd",
+    version="0.1.0",
+    description="Generating gibberish shakespeare",
+    author="Tom Weber",
+    author_email="tom@weber.codes",
+    packages=find_packages(where="src"),
+    package_dir={"": "src"},
+    include_package_data=True,
+    package_data={
+        "dotd": [
+            "saved_models/rnn_2epochs.pth",
+            "data/shakespeare.txt",
+            "data/markdown_header.txt",
+        ]
+    },
+    install_requires=["wheel", "torch", "numpy"],
+)
diff --git a/src/dotd/data.py b/src/dotd/data.py
index 3132813..442a269 100644
--- a/src/dotd/data.py
+++ b/src/dotd/data.py
@@ -1,12 +1,17 @@
 # coding: utf-8
 """ data module """
 import torch
+import os
 from torch.utils.data import Dataset
 
 
 def get_shakespeare():
     """loads the shakespeare text"""
-    return open("data/shakespeare.txt", "rt", encoding="utf-8").read()
+    return open(
+        os.path.join(os.path.dirname(__file__), "data/shakespeare.txt"),
+        "rt",
+        encoding="utf-8",
+    ).read()
 
 
 def preprocess(text: str) -> tuple:
diff --git a/src/dotd/generate.py b/src/dotd/generate.py
index b869199..0e0f45c 100644
--- a/src/dotd/generate.py
+++ b/src/dotd/generate.py
@@ -1,6 +1,7 @@
 # coding: utf-8
 """ generate the drama of the day """
 import re
+import os
 import torch
 from dotd.model import GRU
 from dotd.data import get_shakespeare, preprocess
@@ -8,7 +9,11 @@ from dotd.data import get_shakespeare, preprocess
 
 def get_header():
     """get markdown header file for hugo static website generator"""
-    return open("data/markdown_header.txt", "rt", encoding="utf-8").read()
+    return open(
+        os.path.join(os.path.dirname(__file__), "data/markdown_header.txt"),
+        "rt",
+        encoding="utf-8",
+    ).read()
 
 
 def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -> GRU:
@@ -16,7 +21,7 @@ def load_model(alphabet_size: int, path: str = "saved_models/rnn_2epochs.pth") -
     gru = GRU(alphabet_size, 1024, 256, layers=2, batch=1)  # instantiate model
     gru.load_state_dict(
         torch.load(
-            path,
+            os.path.join(os.path.dirname(__file__), path),
             map_location=torch.device("cpu"),
         )
     )  # load weights
@@ -53,10 +58,14 @@ def write_drama(seq, temp=0.7, max_seq_len=1000):
 
 
 def main():
     """main function"""
     init_seq = "TOM:"
     output_text = get_header() + make_pretty(write_drama(init_seq))
-    with open("dramaoftheday.md", "w", encoding="utf-8") as file:
+    with open(
+        os.path.join(os.path.dirname(__file__), "dramaoftheday.md"),
+        "w",
+        encoding="utf-8",
+    ) as file:
         file.write(output_text)
 
 