# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Data loader for the Penn Treebank (PTB) and WikiText-103 datasets
"""
import os
import sys
import numpy as np
import urllib.request
LICENSE_URL = {
"PTB": "http://www.fit.vutbr.cz/~imikolov/rnnlm/",
"WikiText-103": "https://einstein.ai/research/the-wikitext-long-term-dependency-"
"language-modeling-dataset",
}
SOURCE_URL = {
"PTB": "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz",
"WikiText-103": "https://s3.amazonaws.com/research.metamind.io/wikitext/"
+ "wikitext-103-v1.zip",
}
FILENAME = {"PTB": "simple-examples", "WikiText-103": "wikitext-103"}
EXTENSION = {"PTB": "tgz", "WikiText-103": "zip"}
FILES = {
"PTB": lambda x: "data/ptb." + x + ".txt",
"WikiText-103": lambda x: "wiki." + x + ".tokens",
}
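
# Note on the lookup tables above: for a given dataset and split they compose into a
# relative path, e.g. FILES["WikiText-103"]("train") is "wiki.train.tokens", which is
# read from <data_dir>/wikitext-103/wiki.train.tokens once the archive is extracted.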

class PTBDictionary:
"""
Class for generating a dictionary of all words in the PTB corpus
"""
def __init__(self, data_dir=os.path.expanduser("~/data"), dataset="WikiText-103"):
"""
Initialize class
Args:
data_dir: str, location of data
dataset: str, name of data corpus
"""
self.data_dir = data_dir
self.dataset = dataset
self.filepath = os.path.join(data_dir, FILENAME[self.dataset])
self._maybe_download(data_dir)
self.word2idx = {}
self.idx2word = []
self.load_dictionary()
print("Loaded dictionary of words of size {}".format(len(self.idx2word)))
self.sos_symbol = self.word2idx["<sos>"]
self.eos_symbol = self.word2idx["<eos>"]
self.save_dictionary()
    def add_word(self, word):
        """
        Method for adding a single word to the dictionary
        Args:
            word: str, word to be added
        Returns:
            int, index of the word in the dictionary
        """
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
return self.word2idx[word]
    def load_dictionary(self):
        """
        Populate the dictionary with words from the train, test and valid splits of the corpus
        Returns:
            None
        """
for split_type in ["train", "test", "valid"]:
path = os.path.join(
self.data_dir, FILENAME[self.dataset], FILES[self.dataset](split_type)
)
# Add words to the dictionary
with open(path, "r") as fp:
tokens = 0
for line in fp:
words = ["<sos>"] + line.split() + ["<eos>"]
tokens += len(words)
for word in words:
self.add_word(word)
    def save_dictionary(self):
"""
Save dictionary to file
Returns:
None
"""
with open(os.path.join(self.data_dir, "dictionary.txt"), "w") as fp:
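            # One "word,index" pair per line, e.g. "<sos>,0"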
for k in self.word2idx:
fp.write("%s,%d\n" % (k, self.word2idx[k]))
def _maybe_download(self, work_directory):
"""
        This function downloads the corpus if it's not already present
Args:
work_directory: str, location to download data to
Returns:
None
"""
if not os.path.exists(self.filepath):
            print(
                "{} was not found at {}, looking for a compressed version".format(
                    FILENAME[self.dataset], self.filepath
                )
            )
full_filepath = os.path.join(
work_directory, FILENAME[self.dataset] + "." + EXTENSION[self.dataset]
)
if not os.path.exists(full_filepath):
print("Did not find data")
print(
"PTB can be downloaded from http://www.fit.vutbr.cz/~imikolov/rnnlm/ \n"
"wikitext can be downloaded from"
" https://einstein.ai/research/the-wikitext-long-term-dependency-language"
"-modeling-dataset"
)
print(
"\nThe terms and conditions of the data set license apply. Intel does not "
"grant any rights to the data files or database\n"
)
response = input(
"\nTo download data from {}, please enter YES: ".format(
LICENSE_URL[self.dataset]
)
)
res = response.lower().strip()
if res == "yes" or (len(res) == 1 and res == "y"):
print("Downloading...")
self._download_data(work_directory)
self._uncompress_data(work_directory)
else:
print("Download declined. Response received {} != YES|Y. ".format(res))
print(
"Please download the model manually from the links above "
"and place in directory: {}".format(work_directory)
)
sys.exit()
else:
self._uncompress_data(work_directory)
def _download_data(self, work_directory):
"""
This function downloads the corpus
Args:
work_directory: str, location to download data to
Returns:
None
"""
work_directory = os.path.abspath(work_directory)
if not os.path.exists(work_directory):
os.mkdir(work_directory)
headers = {"User-Agent": "Mozilla/5.0"}
full_filepath = os.path.join(
work_directory, FILENAME[self.dataset] + "." + EXTENSION[self.dataset]
)
req = urllib.request.Request(SOURCE_URL[self.dataset], headers=headers)
data_handle = urllib.request.urlopen(req)
with open(full_filepath, "wb") as fp:
fp.write(data_handle.read())
print("Successfully downloaded data to {}".format(full_filepath))
def _uncompress_data(self, work_directory):
full_filepath = os.path.join(
work_directory, FILENAME[self.dataset] + "." + EXTENSION[self.dataset]
)
if EXTENSION[self.dataset] == "tgz":
import tarfile
with tarfile.open(full_filepath, "r:gz") as tar:
tar.extractall(path=work_directory)
if EXTENSION[self.dataset] == "zip":
import zipfile
with zipfile.ZipFile(full_filepath, "r") as zip_handle:
zip_handle.extractall(work_directory)
print(
"Successfully unzipped data to {}".format(
os.path.join(work_directory, FILENAME[self.dataset])
)
)

class PTBDataLoader:
    """
    Class that defines the data loader for the PTB/WikiText-103 corpus,
    yielding batches of (input, target) token index sequences
    """
def __init__(
self,
word_dict,
seq_len=100,
data_dir=os.path.expanduser("~/data"),
dataset="WikiText-103",
batch_size=32,
skip=30,
split_type="train",
loop=True,
):
"""
Initialize class
Args:
word_dict: PTBDictionary object
seq_len: int, sequence length of data
data_dir: str, location of corpus data
dataset: str, name of corpus
batch_size: int, batch size
skip: int, number of words to skip over while generating batches
split_type: str, train/test/valid
loop: boolean, whether or not to loop over data when it runs out
"""
self.seq_len = seq_len
self.dataset = dataset
self.loop = loop
self.skip = skip
self.word2idx = word_dict.word2idx
self.idx2word = word_dict.idx2word
self.data = self.load_series(
os.path.join(data_dir, FILENAME[self.dataset], FILES[self.dataset](split_type))
)
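        # Candidate window start offsets into the token stream, spaced `skip` tokens
        # apart and shuffled; each offset yields one (input, target) training sequence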
self.random_index = np.random.permutation(
np.arange(0, self.data.shape[0] - self.seq_len, self.skip)
)
self.n_train = self.random_index.shape[0]
self.batch_size = batch_size
self.sample_count = 0
def __iter__(self):
return self
def __next__(self):
return self.get_batch()
    def reset(self):
"""
Resets the sample count to zero, re-shuffles data
Returns:
None
"""
self.sample_count = 0
self.random_index = np.random.permutation(
np.arange(0, self.data.shape[0] - self.seq_len, self.skip)
)
    def get_batch(self):
        """
        Get one batch of the data
        Returns:
            tuple of np.ndarray, (inputs, targets), each of shape (batch_size, seq_len)
        """
if self.sample_count + self.batch_size > self.n_train:
if self.loop:
self.reset()
else:
raise StopIteration("Ran out of data")
batch_x = []
batch_y = []
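        # Each target sequence is its input sequence shifted ahead by one token
        # (next-word prediction)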
for _ in range(self.batch_size):
c_i = int(self.random_index[self.sample_count])
batch_x.append(self.data[c_i : c_i + self.seq_len])
batch_y.append(self.data[c_i + 1 : c_i + self.seq_len + 1])
self.sample_count += 1
batch = (np.array(batch_x), np.array(batch_y))
return batch
    def load_series(self, path):
        """
        Load all the data into an array
        Args:
            path: str, location of the input data file
        Returns:
            np.ndarray, 1-D array of token indices for the whole split
        """
# Tokenize file content
with open(path, "r") as fp:
ids = []
for line in fp:
words = line.split() + ["<eos>"]
for word in words:
ids.append(self.word2idx[word])
data = np.array(ids)
return data
    def decode_line(self, tokens):
        """
        Decode a given line from indices to words
        Args:
            tokens: list of int, word indices
Returns:
str, a sentence
"""
return " ".join([self.idx2word[t] for t in tokens])