Source code for nlp_architect.models.bist.utils

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# pylint: disable=deprecated-module
import os
import subprocess
from collections import Counter

from nlp_architect.data.conll import ConllEntry
from nlp_architect.models.bist.eval.conllu.conll17_ud_eval import run_conllu_eval


# Things that were changed from the original:
# - Removed ConllEntry class, normalize()
# - Changed read_conll() and write_conll() input from file to path
# - Added run_eval(), get_options_dict() and is_conllu()
# - Reformatted code and variable names to conform with PEP8
# - Added legal header


[docs]def vocab(conll_path): # pylint: disable=missing-docstring words_count = Counter() pos_count = Counter() rel_count = Counter() for sentence in read_conll(conll_path): words_count.update([node.norm for node in sentence if isinstance(node, ConllEntry)]) pos_count.update([node.pos for node in sentence if isinstance(node, ConllEntry)]) rel_count.update([node.relation for node in sentence if isinstance(node, ConllEntry)]) return ( words_count, {w: i for i, w in enumerate(words_count.keys())}, list(pos_count.keys()), list(rel_count.keys()), )
[docs]def read_conll(path): """Yields CoNLL sentences read from CoNLL formatted file..""" with open(path, "r") as conll_fp: root = ConllEntry( 0, "*root*", "*root*", "ROOT-POS", "ROOT-CPOS", "_", -1, "rroot", "_", "_" ) tokens = [root] for line in conll_fp: stripped_line = line.strip() tok = stripped_line.split("\t") if not tok or line.strip() == "": if len(tokens) > 1: yield tokens tokens = [root] else: if line[0] == "#" or "-" in tok[0] or "." in tok[0]: # noinspection PyTypeChecker tokens.append(stripped_line) else: tokens.append( ConllEntry( int(tok[0]), tok[1], tok[2], tok[4], tok[3], tok[5], int(tok[6]) if tok[6] != "_" else -1, tok[7], tok[8], tok[9], ) ) if len(tokens) > 1: yield tokens
[docs]def write_conll(path, conll_gen): """Writes CoNLL sentences to CoNLL formatted file.""" with open(path, "w") as file: for sentence in conll_gen: for entry in sentence[1:]: file.write(str(entry) + "\n") file.write("\n")
[docs]def run_eval(gold, test): """Evaluates a set of predictions using the appropriate script.""" if is_conllu(gold): run_conllu_eval(gold_file=gold, test_file=test) else: eval_script = os.path.join(os.path.dirname(os.path.realpath(__file__)), "eval", "eval.pl") with open(test[: test.rindex(".")] + "_eval.txt", "w") as out_file: subprocess.run(["perl", eval_script, "-g", gold, "-s", test], stdout=out_file)
[docs]def is_conllu(path): """Determines if the file is in CoNLL-U format.""" return os.path.splitext(path.lower())[1] == ".conllu"
[docs]def get_options_dict(activation, lstm_dims, lstm_layers, pos_dims): """Generates dictionary with all parser options.""" return { "activation": activation, "lstm_dims": lstm_dims, "lstm_layers": lstm_layers, "pembedding_dims": pos_dims, "wembedding_dims": 100, "rembedding_dims": 25, "hidden_units": 100, "hidden2_units": 0, "learning_rate": 0.1, "blstmFlag": True, "labelsFlag": True, "bibiFlag": True, "costaugFlag": True, "seed": 0, "mem": 0, }