Source code for nlp_architect.data.fasttext_emb

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import io
import os
import sys

import numpy as np
from six.moves import urllib

from nlp_architect.utils.generic import license_prompt


class FastTextEmb:
    """
    Downloads FastText embeddings for a given language to the given path
    and loads them into memory.

    Arguments:
        path(str): Local directory where the ``wiki.<language>.vec`` file is
            stored (and downloaded to when missing)
        language(str): Embeddings language code used to build the file name
        vocab_size(int): Maximum number of distinct words to read
        emb_dim(int): Expected embedding dimension (checked against the
            header line of the embeddings file)

    ``load_embeddings`` returns:
        A ``Dictionary`` (word <-> index mapping) and a float32 numpy array
        of shape (number_of_words_read, emb_dim).
    """

    def __init__(self, path, language, vocab_size, emb_dim=300):
        self.path = path
        self.language = language
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.url = "https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki." + language + ".vec"

    def _maybe_download(self):
        """
        Download the embeddings file from ``self.url`` unless it is already
        present in ``self.path``.

        The download only happens when ``license_prompt`` returns a truthy
        value; otherwise the process exits via ``SystemExit``.

        Returns:
            str: Path of the local embeddings file
        """
        filename = "wiki." + self.language + ".vec"
        filepath = os.path.join(self.path, filename)
        link = "https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md"

        if os.path.exists(filepath):
            print("Found FastText embeddings for " + self.language + " at " + filepath)
            return filepath

        if not license_prompt(filepath, link, self.path):
            sys.exit()

        # urlretrieve cannot create missing parent directories itself.
        if self.path and not os.path.isdir(self.path):
            os.makedirs(self.path)
        print("Downloading FastText embeddings for " + self.language + " to " + filepath)
        urllib.request.urlretrieve(self.url, filepath)
        statinfo = os.stat(filepath)
        print("Sucessfully downloaded", filename, statinfo.st_size, "bytes")
        return filepath

    def read_embeddings(self, filepath):
        """
        Parse a FastText ``.vec`` text file.

        The first line must hold two fields (word count, embedding dimension);
        every following line is ``<word> <v1> <v2> ...``. Reading stops once
        ``self.vocab_size`` distinct words have been collected.

        Arguments:
            filepath(str): Path of the ``.vec`` file to read

        Returns:
            A ``Dictionary`` for the words read and a numpy array of shape
            (number_of_words_read, emb_dim) holding their vectors, row ``i``
            belonging to the word with index ``i``.
        """
        word2id = {}
        vectors = []
        # .vec files are UTF-8; newline="\n" keeps a stray "\r" inside a token
        # from being treated as a line break.
        with io.open(filepath, "r", encoding="utf-8", newline="\n") as emb_file:
            for line_no, line in enumerate(emb_file):
                # Line zero has total words, emb dimensions
                if line_no == 0:
                    header = line.split()
                    assert len(header) == 2, "malformed header line"
                    assert self.emb_dim == int(header[1]), "unexpected embedding dimension"
                    continue

                # Stop as soon as the requested vocabulary size is reached
                if len(word2id) >= self.vocab_size:
                    break

                # Remaining lines are in "word vector" format
                word, vector = line.rstrip().split(" ", 1)
                if word in word2id:
                    # A repeated word would give two words the same id and
                    # desynchronise ids from embedding rows; keep the first.
                    continue
                vector = np.fromstring(vector, sep=" ")
                # Validate the shape before touching vector[0] below
                assert vector.shape == (self.emb_dim,), line_no
                # If norm is zero fill with 0.01
                if np.linalg.norm(vector) == 0:
                    vector[0] = 0.01
                # Assign the next free index to the word
                word2id[word] = len(word2id)
                vectors.append(vector[None])

        # Reverse dictionary
        id2word = {idx: word for word, idx in word2id.items()}
        # Dictionary just combines both id2word and word2id into one object
        dico = Dictionary(id2word, word2id, self.language)
        # Stack the (1, emb_dim) rows; np.concatenate rejects an empty list
        if vectors:
            word_vec = np.concatenate(vectors, 0)
        else:
            word_vec = np.empty((0, self.emb_dim))
        return dico, word_vec

    def load_embeddings(self):
        """
        Make sure the embeddings file is available locally, then read it.

        Returns:
            A ``Dictionary`` and a float32 numpy array of shape
            (number_of_words_read, emb_dim).
        """
        # Check if embeddings exist else download
        filepath = self._maybe_download()
        # Read embeddings
        dico, word_vec = self.read_embeddings(filepath)
        print("Completed loading embeddings for " + self.language)
        word_vec = np.float32(word_vec)
        return dico, word_vec
def get_eval_data(eval_path, src_lang, tgt_lang):
    """
    Downloads evaluation cross lingual dictionaries to the eval_path

    Arguments:
        eval_path: Directory where cross-lingual dictionaries are downloaded
        src_lang : Source Language
        tgt_lang : Target Language

    Returns:
        Path of the ``<src_lang>-<tgt_lang>.5000-6500.txt`` dictionary file
        inside eval_path. If the file is missing and ``license_prompt``
        returns a falsy value, the process exits via ``SystemExit``.
    """
    eval_url = "https://s3.amazonaws.com/arrival/dictionaries/"
    link = "https://github.com/facebookresearch/MUSE#ground-truth-bilingual-dictionaries"
    filename = src_lang + "-" + tgt_lang + ".5000-6500.txt"
    src_path = os.path.join(eval_path, filename)

    # Nothing to do when the dictionary has already been downloaded
    if os.path.exists(src_path):
        return src_path

    # NOTE(review): FastTextEmb passes a directory as the third argument of
    # license_prompt; here the file path is passed - confirm which is intended.
    if not license_prompt(src_path, link, src_path):
        sys.exit()

    # Create the target directory without going through a shell, so paths
    # containing spaces or shell metacharacters are handled safely.
    if eval_path and not os.path.isdir(eval_path):
        os.makedirs(eval_path)
    print("Downloading cross-lingual dictionaries for " + src_lang)
    urllib.request.urlretrieve(eval_url + filename, src_path)
    print("Completed downloading to " + eval_path)
    return src_path
class Dictionary:
    """
    Merges word2idx and idx2word dictionaries

    Arguments:
        id2word dictionary: maps each index in 0..len-1 to its word
        word2id dictionary: maps each word to its index
        language of the dictionary

    Usage:
        dico.index(word) - returns an index
        dico[index] - returns the word
        word in dico - tells whether the word is known
        len(dico) - number of words
    """

    # Instances are mutable and compare by content, so they are not hashable.
    __hash__ = None

    def __init__(self, id2word, word2id, lang):
        assert len(id2word) == len(word2id)
        self.id2word = id2word
        self.word2id = word2id
        self.lang = lang
        self.check_valid()

    def __len__(self):
        """
        Returns the number of words in the dictionary.
        """
        return len(self.id2word)

    def __getitem__(self, i):
        """
        Returns the word of the specified index.
        """
        return self.id2word[i]

    def __contains__(self, w):
        """
        Returns whether a word is in the dictionary.
        """
        return w in self.word2id

    def __eq__(self, y):
        """
        Compare the dictionary with another one.

        Two dictionaries are equal when they have the same language and the
        same word at every index. Comparing with an object that is not a
        Dictionary yields NotImplemented, so ``dico == other`` evaluates to
        False instead of raising.
        """
        if not isinstance(y, Dictionary):
            return NotImplemented
        self.check_valid()
        y.check_valid()
        if len(self.id2word) != len(y):
            return False
        return self.lang == y.lang and all(self.id2word[i] == y[i] for i in range(len(y)))

    def check_valid(self):
        """
        Check that the dictionary is valid.

        Both mappings must have the same size and id2word must map every
        index in 0..len-1 to a word that word2id maps back to that index.
        """
        assert len(self.id2word) == len(self.word2id)
        # Iterating over range (not over the keys) also verifies that the
        # indices are exactly 0..len-1.
        for i in range(len(self.id2word)):
            assert self.word2id[self.id2word[i]] == i

    def index(self, word):
        """
        Returns the index of the specified word.
        """
        return self.word2id[word]